Commit 9bebdcb5 authored by: jingqinghe

update lite code

......@@ -2193,7 +2193,13 @@ void pooling3x3s2p1_max(const float* din,
w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4;
}
float32x4_t vmin = vdupq_n_f32(std::numeric_limits<float>::lowest());
int w_needed = wout * 2 + 1;
int pad_right_ = w_needed - win - pad_bottom;
int w_2 = pad_right_ > 0 ? w_unroll_remian : w_unroll_remian + 1;
w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vmin = vdupq_n_f32(minval);
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
......@@ -2232,6 +2238,11 @@ void pooling3x3s2p1_max(const float* din,
break;
}
}
auto pr0 = dr0;
auto pr1 = dr1;
auto pr2 = dr2;
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
#ifdef __aarch64__
......@@ -2285,27 +2296,53 @@ void pooling3x3s2p1_max(const float* din,
"q11",
"q15");
#endif
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
}
// deal with right pad
int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = dr0[0];
for (int i = 0; i < wend - st; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
} else {
float tmp = minval;
for (int i = 0; i < 2; i++) {
tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
tmp = std::max(tmp, dr2[i]);
}
*(dr_out++) = tmp;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
dr2 += S - (st - wstart);
wstart += S;
dr_out[0] = tmp;
dr0++;
dr1++;
dr2++;
dr_out++;
}
for (int w = 0; w < w_2 - 1; w += 1) {
float32x4_t vr0 = vld1q_f32(dr0);
float32x4_t vr1 = vld1q_f32(dr1);
float32x4_t vr2 = vld1q_f32(dr2);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
}
if (pad_right_) {
float tmp = minval;
for (int i = 1; i < 3; i++) {
tmp = std::max(tmp, std::max(pr0[win - i], pr1[win - i]));
tmp = std::max(tmp, pr2[win - i]);
}
dr_out[0] = tmp;
}
data_out_channel += wout;
}
}
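The NEON tail loop above (the `w_2 - 1` iterations) produces one output per 3-wide, stride-2 window: it loads four lanes from each of the three rows, masks lane 3 with `minval`, takes the element-wise maximum across rows, and then pairwise-max reduces the vector. A scalar sketch of what one iteration computes, for illustration only (names follow the surrounding code):

// Scalar equivalent of one iteration of the NEON loop above (illustration
// only): the window covers dr0[0..2], dr1[0..2], dr2[0..2]; lane 3 is
// overwritten with minval so it never wins the reduction.
float window_max = minval;
for (int i = 0; i < 3; ++i) {
  window_max = std::max(window_max, dr0[i]);
  window_max = std::max(window_max, dr1[i]);
  window_max = std::max(window_max, dr2[i]);
}
dr_out[0] = window_max;  // one pooled output element
dr_out++;
dr0 += 2;  // advance by the stride S == 2
dr1 += 2;
dr2 += 2;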
......@@ -2539,6 +2576,10 @@ void pooling3x3s2p0_max(const float* din,
int remain = w_unroll_remian - 1;
int right = wout * 2 + 1 - win; // if need right pad
int w_2 = right > 0 ? w_unroll_remian : w_unroll_remian + 1;
w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2;
float minval = std::numeric_limits<float>::lowest();
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
......@@ -2592,18 +2633,24 @@ void pooling3x3s2p0_max(const float* din,
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
int rem = win - (w_unroll_size * 4) * S;
int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem);
float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, dr2[i]);
}
*(dr_out++) = tmp;
wstart += S;
for (int w = 0; w < w_2 - 1; w += 1) {
float32x4_t vr0 = vld1q_f32(dr0);
float32x4_t vr1 = vld1q_f32(dr1);
float32x4_t vr2 = vld1q_f32(dr2);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
}
#else
asm volatile(
......
......@@ -79,6 +79,13 @@ void slice(const Dtype* input,
}
}
template void slice(const float* input,
std::vector<int64_t> dims,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
float* out,
Context<TARGET(kARM)>* ctx);
template void slice(const int* input,
std::vector<int64_t> dims,
std::vector<int> axes,
......@@ -86,12 +93,12 @@ template void slice(const int* input,
std::vector<int> ends,
int* out,
Context<TARGET(kARM)>* ctx);
template void slice(const float* input,
template void slice(const int64_t* input,
std::vector<int64_t> dims,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
float* out,
int64_t* out,
Context<TARGET(kARM)>* ctx);
} // namespace math
......
......@@ -51,11 +51,11 @@ void split_cpy<float>(const float* din, float* dout, int num) {
}
}
template <>
void split<float>(const float* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides) {
template <typename T>
void split(const T* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides) {
int input_offset = 0;
for (auto out : dout) {
auto out_dim = out->dims();
......@@ -65,15 +65,15 @@ void split<float>(const float* din,
out_strides[i] = out_strides[i + 1] * out_dim[i];
}
float* out_data = out->mutable_data<float>();
T* out_data = out->mutable_data<T>();
int before = out_strides[0] / out_strides[axis];
int in_after = in_strides[axis];
int out_after = out_strides[axis];
const float* din_ptr = din + input_offset;
const T* din_ptr = din + input_offset;
for (int i = 0; i < before; ++i) {
std::memcpy(out_data, din_ptr, sizeof(float) * out_after);
std::memcpy(out_data, din_ptr, sizeof(T) * out_after);
din_ptr += in_after;
out_data += out_after;
}
......@@ -81,6 +81,15 @@ void split<float>(const float* din,
}
}
template void split(const float* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);
template void split(const int64_t* din,
const std::vector<lite::Tensor*>& dout,
const int axis,
const std::vector<int>& in_strides);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -20,6 +20,7 @@ nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
nv_library(cuda_bias SRCS bias.cu DEPS ${cuda_static_deps})
nv_library(cuda_sequence_helper SRCS sequence_helper.cu DEPS ${cuda_static_deps})
set (
math_cuda
......@@ -39,6 +40,7 @@ set (
cuda_sequence_padding
cuda_bias
cudnn_helper
cuda_sequence_helper
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
......@@ -55,31 +55,32 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
CUDNN_TENSOR_NCHW,
GetCudnnDataType<Ptype_out>(),
cudnn::cudnnTypeWrapper<T>::type,
batch,
ic,
ih,
iw));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
GetCudnnDataType<Ptype_out>(),
cudnn::cudnnTypeWrapper<T>::type,
CUDNN_TENSOR_NCHW,
oc,
ic / param.groups,
kh,
kw));
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
ph,
pw,
sh,
sw,
dh,
dw,
CUDNN_CROSS_CORRELATION,
GetCudnnDataType<Ptype_out>()));
CUDNN_CHECK(
cudnnSetConvolution2dDescriptor(this->conv_desc_,
ph,
pw,
sh,
sw,
dh,
dw,
CUDNN_CROSS_CORRELATION,
cudnn::cudnnTypeWrapper<T>::type));
CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
CUDNN_TENSOR_NCHW,
GetCudnnDataType<Ptype_out>(),
cudnn::cudnnTypeWrapper<T>::type,
batch,
oc,
oh,
......@@ -179,7 +180,7 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
int dim_bias[] = {1, oc, 1, 1};
int stride_bias[] = {oc, 1, 1, 1};
cudnnSetTensorNdDescriptor(this->bias_desc_,
GetCudnnDataType<Ptype_out>(),
cudnn::cudnnTypeWrapper<T>::type,
4,
dim_bias,
stride_bias);
......
......@@ -21,17 +21,7 @@ namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <>
cudnnDataType_t GetCudnnDataType<PRECISION(kFloat)>() {
return CUDNN_DATA_FLOAT;
}
template <>
cudnnDataType_t GetCudnnDataType<PRECISION(kFP16)>() {
return CUDNN_DATA_HALF;
}
namespace cudnn {} // namespace cudnn
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -25,10 +25,97 @@ namespace paddle {
namespace lite {
namespace cuda {
namespace math {
namespace cudnn {
template <lite_api::PrecisionType PType>
cudnnDataType_t GetCudnnDataType();
template <typename T>
class cudnnTypeWrapper;
template <>
class cudnnTypeWrapper<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0f;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0f;
return &v;
}
};
template <>
class cudnnTypeWrapper<half> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
typedef const half ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = __float2half(1.0f);
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = __float2half(0.0f);
return &v;
}
};
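A minimal usage sketch of the wrapper introduced above, written as it might appear in a caller (illustration only; it assumes this header is the one included as lite/backends/cuda/math/cudnn_helper.h): `type` replaces the removed GetCudnnDataType<>() specializations, and `kOne()` / `kZero()` give typed alpha/beta constants for cuDNN calls.

// Sketch only; not part of this commit.
using FloatWrapper = paddle::lite::cuda::math::cudnn::cudnnTypeWrapper<float>;

// Data-type enum used when building cuDNN descriptors:
cudnnDataType_t dtype = FloatWrapper::type;   // CUDNN_DATA_FLOAT

// Typed scaling constants usable as cuDNN alpha/beta pointers:
const float* alpha = FloatWrapper::kOne();    // points to 1.0f
const float* beta  = FloatWrapper::kZero();   // points to 0.0f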
struct ParamsRegion {
ParamsRegion() : offset_(nullptr), size_(0) {}
ParamsRegion(void* offset, size_t size) : offset_(offset), size_(size) {}
~ParamsRegion() {}
ParamsRegion& operator=(const ParamsRegion& right) {
offset_ = right.offset_;
size_ = right.size_;
return *this;
}
bool operator==(const ParamsRegion& right) {
bool comp_eq = true;
comp_eq = comp_eq && (offset_ == right.offset_);
comp_eq = comp_eq && (size_ == right.size_);
return comp_eq;
}
void* offset_;
size_t size_;
};
template <typename T>
class TensorDescriptors {
public:
TensorDescriptors(size_t n,
const std::vector<std::vector<int>>& dim,
const std::vector<std::vector<int>>& stride) {
descs_.resize(n);
CHECK_EQ(dim.size(), stride.size())
<< "dim size should be equal to stride size";
for (size_t i = 0; i < n; ++i) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&descs_[i]));
CUDNN_CHECK(cudnnSetTensorNdDescriptor(descs_[i],
cudnnTypeWrapper<T>::type,
dim[i].size(),
dim[i].data(),
stride[i].data()));
}
}
~TensorDescriptors() {
for (auto desc : descs_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
}
}
const cudnnTensorDescriptor_t* descs() const { return descs_.data(); }
int size() const { return descs_.size(); }
private:
std::vector<cudnnTensorDescriptor_t> descs_;
};
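A short usage sketch for TensorDescriptors, written from a caller's point of view (illustration only; the 3-D RNN-style dims and strides below are made up):

// Sketch only: describe two time steps of shape [batch=4, hidden=32, 1]
// with fully packed strides, the layout cuDNN RNN-style APIs expect for
// an array of per-step tensor descriptors.
std::vector<std::vector<int>> dims = {{4, 32, 1}, {4, 32, 1}};
std::vector<std::vector<int>> strides = {{32, 1, 1}, {32, 1, 1}};
paddle::lite::cuda::math::cudnn::TensorDescriptors<float> x_descs(
    2, dims, strides);
// x_descs.descs() yields a cudnnTensorDescriptor_t array of length
// x_descs.size() == 2, and the descriptors are destroyed in ~TensorDescriptors.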
} // namespace cudnn
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -54,7 +54,7 @@ bool CudnnSoftmax<T, Ptype>::Create(const operators::SoftmaxParam& param,
const int stride_c = H * stride_h;
const int stride_n = C * stride_c;
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(bottom_desc_,
GetCudnnDataType<Ptype>(),
cudnn::cudnnTypeWrapper<T>::type,
N,
C,
H,
......@@ -64,7 +64,7 @@ bool CudnnSoftmax<T, Ptype>::Create(const operators::SoftmaxParam& param,
stride_h,
stride_w));
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(top_desc_,
GetCudnnDataType<Ptype>(),
cudnn::cudnnTypeWrapper<T>::type,
N,
C,
H,
......
......@@ -30,17 +30,12 @@ __global__ void CopyMatrixRowsKernel(const T* src,
int height,
int width,
bool is_src_index) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int row_id = blockDim.y * blockIdx.x + idy;
if (row_id < height) {
int src_idx = is_src_index ? index[row_id] : row_id;
int dst_idx = is_src_index ? row_id : index[row_id];
const T* src_data = src + src_idx * width;
T* dst_data = dst + dst_idx * width;
for (int i = idx; i < width; i += blockDim.x) {
dst_data[i] = src_data[i];
}
CUDA_KERNEL_LOOP(tid, height * width) {
int row = tid / width;
int idx = tid % width;
int src_row = is_src_index ? index[row] : row;
int dst_row = is_src_index ? row : index[row];
dst[dst_row * width + idx] = src[src_row * width + idx];
}
}
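The kernel above now iterates with CUDA_KERNEL_LOOP instead of a 2-D thread block. A sketch of the assumed, Caffe/Paddle-style definition of that macro (illustration only; the real definition lives in the included lite/backends/cuda/cuda_utils.h):

// Assumed grid-stride loop macro (sketch, not this repository's definition):
#define CUDA_KERNEL_LOOP(i, n)                                   \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
       i += blockDim.x * gridDim.x)
// With CUDA_GET_BLOCKS(height * width) blocks of CUDA_NUM_THREADS threads,
// each thread handles elements tid, tid + stride, ..., so the whole
// height x width matrix is covered regardless of the launched grid size.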
......@@ -69,9 +64,8 @@ void CopyMatrixRowsFunctor<T>::operator()(
sizeof(uint64_t) * index_lod.size(),
IoDirection::HtoD,
stream);
dim3 threads(128, 8);
dim3 grids((height + threads.y - 1) / threads.y);
CopyMatrixRowsKernel<T><<<grids, threads, 0, stream>>>(
CopyMatrixRowsKernel<
T><<<CUDA_GET_BLOCKS(height * width), CUDA_NUM_THREADS, 0, stream>>>(
src_data, dst_data, index_tensor_data, height, width, is_src_index);
CUDA_POST_KERNEL_CHECK;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/sequence_helper.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename Dtype>
__global__ void Map2Out(
Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < count) {
int seq = tid / lastdim;
output[map[seq] * lastdim + tid % lastdim] = input[tid];
}
}
template <typename Dtype>
__global__ void Map2In(
Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < count) {
int seq = tid / lastdim;
output[tid] = input[map[seq] * lastdim + tid % lastdim];
}
}
template <typename Dtype>
void Map2OutFunc(const Dtype* input,
Dtype* output,
int word_size,
int seq_sum,
cudaStream_t stream,
int* dev_map_vec) {
int count = seq_sum * word_size;
int block_dim = count;
int grid_dim = 1;
if (count > 1024) {
block_dim = 256;
grid_dim = (count + block_dim - 1) / block_dim;
}
Map2Out<<<grid_dim, block_dim, 0, stream>>>(
output, input, dev_map_vec, count, word_size);
}
template <typename Dtype>
void Map2InFunc(const Dtype* input,
Dtype* output,
int hidden_size,
int seq_sum,
cudaStream_t stream,
int* dev_map_vec) {
int count = seq_sum * hidden_size;
int block_dim = count;
int grid_dim = 1;
if (count > 1024) {
block_dim = 256;
grid_dim = (count + block_dim - 1) / block_dim;
}
Map2In<<<grid_dim, block_dim, 0, stream>>>(
output, input, dev_map_vec, count, hidden_size);
}
template <typename Dtype>
void SeqSortedseqTranseUtil::Seq2SortedSeq(const Dtype* input,
Dtype* output,
int word_size,
cudaStream_t stream) {
int seq_sum = map_vec_.size();
Map2OutFunc(input, output, word_size, seq_sum, stream, dev_map_vec_);
}
template <typename Dtype>
void SeqSortedseqTranseUtil::SortedSeq2Seq(const Dtype* input,
Dtype* output,
int hidden_size,
cudaStream_t stream) {
int seq_sum = map_vec_.size();
Map2InFunc(input, output, hidden_size, seq_sum, stream, dev_map_vec_);
}
bool SeqSortedseqTranseUtil::GetSortedMap(const std::vector<int>& offset_vec,
cudaStream_t stream_id) {
int batch_size = offset_vec.size() - 1;
int word_sum = offset_vec[offset_vec.size() - 1];
std::vector<int> length_vec(batch_size);
length_index_.resize(batch_size);
int emit_length = 0;
if (batch_size == 1) {
emit_length = offset_vec[1] - offset_vec[0];
emit_offset_vec_.resize(emit_length + 1);
for (int i = 0; i <= emit_length; ++i) {
emit_offset_vec_[i] = i;
}
return false;
}
int max_len = 0;
for (int i = 0; i < offset_vec.size() - 1; ++i) {
int len = offset_vec[i + 1] - offset_vec[i];
max_len = max_len > len ? max_len : len;
length_vec[i] = len;
length_index_[i] = i;
}
emit_length = max_len;
if (max_len == 1) {
emit_offset_vec_.resize(2);
emit_offset_vec_[0] = 0;
emit_offset_vec_[1] = emit_length * batch_size;
return false;
}
std::stable_sort(length_index_.begin(),
length_index_.end(),
[&length_vec](int i1, int i2) {
return length_vec[i1] > length_vec[i2];
});
emit_offset_vec_.resize(max_len + 1);
map_vec_.resize(word_sum);
if (word_sum > dev_map_vec_length_) {
if (dev_map_vec_ != nullptr) {
TargetWrapperCuda::Free(static_cast<void*>(dev_map_vec_));
}
dev_map_vec_ =
static_cast<int*>(TargetWrapperCuda::Malloc(sizeof(int) * word_sum));
dev_map_vec_length_ = word_sum;
}
int target_word_id = 0;
std::vector<int> length_vec_cnt = length_vec;
int last_batch_size = batch_size;
for (int word_id_in_seq = 0; word_id_in_seq < max_len; word_id_in_seq++) {
emit_offset_vec_[word_id_in_seq] = target_word_id;
for (int batch_id = 0; batch_id < last_batch_size; batch_id++) {
int old_batch_id = length_index_[batch_id];
if (length_vec_cnt[old_batch_id] > 0) {
int inner_word_id_in_seq = word_id_in_seq;
if (is_reverse_) {
inner_word_id_in_seq = length_vec[old_batch_id] - 1 - word_id_in_seq;
}
int old_word_id = offset_vec[old_batch_id] + inner_word_id_in_seq;
map_vec_[old_word_id] = target_word_id;
length_vec_cnt[old_batch_id]--;
target_word_id++;
} else {
last_batch_size--;
break;
}
}
}
TargetWrapperCuda::MemcpyAsync(dev_map_vec_,
map_vec_.data(),
sizeof(int) * word_sum,
IoDirection::HtoD,
stream_id);
emit_offset_vec_[max_len] = word_sum;
emit_length_ = emit_length;
return true;
}
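A small worked example of GetSortedMap may help here (illustration only): two sequences with offset_vec = {0, 2, 5}, i.e. lengths 2 and 3, and is_reverse_ == false.

// Worked example (illustration only), offset_vec = {0, 2, 5}:
//   length_vec       = {2, 3}      -> sorted by length desc: seq 1, then seq 0
//   map_vec_         = {1, 3, 0, 2, 4}
//     word 0 (seq0, t0) -> slot 1     word 2 (seq1, t0) -> slot 0
//     word 1 (seq0, t1) -> slot 3     word 3 (seq1, t1) -> slot 2
//                                     word 4 (seq1, t2) -> slot 4
//   emit_offset_vec_ = {0, 2, 4, 5}  // time step t owns slots
//                                    // [emit_offset_vec_[t], emit_offset_vec_[t+1])
// Seq2SortedSeq then scatters row i of the input to row map_vec_[i] of this
// batched, time-major layout, which is what the GRU time-step loop consumes.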
template void SeqSortedseqTranseUtil::Seq2SortedSeq(const float* input,
float* output,
int word_size,
cudaStream_t stream);
template void SeqSortedseqTranseUtil::SortedSeq2Seq(const float* input,
float* output,
int hidden_size,
cudaStream_t stream);
template void SeqSortedseqTranseUtil::Seq2SortedSeq(const half* input,
half* output,
int word_size,
cudaStream_t stream);
template void SeqSortedseqTranseUtil::SortedSeq2Seq(const half* input,
half* output,
int hidden_size,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include "lite/backends/cuda/target_wrapper.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
class SeqSortedseqTranseUtil {
public:
explicit SeqSortedseqTranseUtil(bool is_reverse = false, bool is_bi = false)
: is_reverse_(is_reverse),
is_bi_(is_bi),
dev_map_vec_(nullptr),
dev_map_vec_length_(0) {}
~SeqSortedseqTranseUtil() {
if (dev_map_vec_ != nullptr) {
TargetWrapperCuda::Free(static_cast<void*>(dev_map_vec_));
}
}
std::vector<int>& GetLengthIndex() { return length_index_; }
std::vector<int>& GetEmitOffsetVec() { return emit_offset_vec_; }
std::vector<int>& GetMapVec() { return map_vec_; }
int* GetDevMapVec() { return dev_map_vec_; }
int GetEmitLength() { return emit_length_; }
template <typename Dtype>
void Seq2SortedSeq(const Dtype* input,
Dtype* output,
int word_size,
cudaStream_t stream);
template <typename Dtype>
void SortedSeq2Seq(const Dtype* input,
Dtype* output,
int hidden_size,
cudaStream_t stream);
bool GetSortedMap(const std::vector<int>& offset_vec, cudaStream_t stream_id);
private:
std::vector<int> length_index_;
std::vector<int> emit_offset_vec_;
std::vector<int> map_vec_;
int emit_length_;
bool is_reverse_;
bool is_bi_;
int* dev_map_vec_;
int dev_map_vec_length_;
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
......@@ -76,6 +76,7 @@ bool Device::Build(std::vector<ge::Operator>& input_nodes, // NOLINT
}
}
VLOG(3) << "Getting input node size " << input_nodes.size();
VLOG(3) << "Getting output node size " << output_nodes.size();
ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes);
// Build IR model
......
......@@ -96,7 +96,9 @@ bool AclModelClient::GetModelIOTensorDim(
ACL_CALL(aclmdlGetInputDims(model_desc_, i, &input_dim));
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
aclFormat data_format = aclmdlGetInputFormat(model_desc_, i);
TensorDesc tensor_desc = TensorDesc(data_type, input_dim, data_format);
const std::string name_str(aclmdlGetInputNameByIndex(model_desc_, i));
TensorDesc tensor_desc =
TensorDesc(name_str, data_type, input_dim, data_format);
input_tensor->push_back(tensor_desc);
}
......@@ -108,7 +110,9 @@ bool AclModelClient::GetModelIOTensorDim(
ACL_CALL(aclmdlGetOutputDims(model_desc_, i, &output_dim));
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
aclFormat data_format = aclmdlGetOutputFormat(model_desc_, i);
TensorDesc tensor_desc = TensorDesc(data_type, output_dim, data_format);
const std::string name_str(aclmdlGetOutputNameByIndex(model_desc_, i));
TensorDesc tensor_desc =
TensorDesc(name_str, data_type, output_dim, data_format);
output_tensor->push_back(tensor_desc);
}
return true;
......@@ -118,12 +122,10 @@ bool AclModelClient::GetTensorFromDataset(
std::vector<std::shared_ptr<ge::Tensor>>* output_tensor) {
size_t device_output_num = aclmdlGetDatasetNumBuffers(output_dataset_);
size_t tensor_output_num = reinterpret_cast<size_t>(output_tensor->size());
if (device_output_num != tensor_output_num) {
LOG(ERROR)
<< "[HUAWEI_ASCEND_NPU] output number not equal, device number is "
<< device_output_num << "tensor number is " << tensor_output_num;
return false;
}
CHECK_EQ(device_output_num, tensor_output_num)
<< "[HUAWEI_ASCEND_NPU] tensor output number should equal to device "
"output number, device output number is "
<< device_output_num << ", tensor output number is " << tensor_output_num;
for (size_t i = 0; i < device_output_num; i++) {
aclDataBuffer* buffer_device = aclmdlGetDatasetBuffer(output_dataset_, i);
void* device_data = aclGetDataBufferAddr(buffer_device);
......@@ -195,7 +197,10 @@ void AclModelClient::CreateOutputDataset(
return;
}
size_t output_size = aclmdlGetNumOutputs(model_desc_);
CHECK_EQ(output_size, output_tensor->size());
CHECK_EQ(output_size, output_tensor->size())
<< "[HUAWEI_ASCEND_NPU] model output number should equal to output "
"tensor size, model output number is "
<< output_size << ", output tensor number is " << output_tensor->size();
for (size_t i = 0; i < output_size; i++) {
size_t buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);
void* buffer_device = nullptr;
......
......@@ -25,15 +25,20 @@ namespace huawei_ascend_npu {
class TensorDesc {
public:
TensorDesc(aclDataType data_type, aclmdlIODims dims, aclFormat format) {
TensorDesc(const std::string name,
aclDataType data_type,
aclmdlIODims dims,
aclFormat format) {
if (format == ACL_FORMAT_NHWC) {
dim_order[1] = 3;
dim_order[2] = 1;
dim_order[3] = 2;
}
// create ge::Tensordesc
VLOG(3) << "[HUAWEI_ASCEND_NPU] Getting tensor name : " << name;
ge_tensor_desc_ = new ge::TensorDesc(
GetGeShape(dims), GetGeFormat(format), GetGeDataType(data_type));
ge_tensor_desc_->SetName(name);
CHECK(ge_tensor_desc_ != nullptr);
VLOG(3) << "[HUAWEI_ASCEND_NPU] Getting data shape : " << repr();
}
......
......@@ -133,7 +133,7 @@ lite_cc_library(type_system SRCS type_system.cc DEPS tensor target_wrapper)
lite_cc_library(program SRCS program.cc
DEPS op kernel model_parser ${ops} ${cpp_wrapper}
PROFILE_DEPS lite_profiler
CUDA_DEPS nvtx_wrapper)
CUDA_DEPS nvtx_wrapper cuda_type_trans)
if (NOT LITE_ON_TINY_PUBLISH)
lite_cc_library(optimizer SRCS optimizer.cc DEPS mir_pass_manager model_parser program)
......
......@@ -44,7 +44,7 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kXPU)})
#ifndef LITE_WITH_MLU
#if (!defined(LITE_WITH_MLU) && !defined(LITE_WITH_HUAWEI_ASCEND_NPU))
.ExcludeTargets({TARGET(kX86)})
#endif
.ExcludeTargets({TARGET(kBM)})
......
......@@ -29,6 +29,10 @@ namespace kernels {
namespace apu {
bool SubgraphEngine::BuildDeviceProgram() {
if (!origin_program_) {
BuildOriginProgram();
}
unsigned int version;
Neuron_getVersion(&version);
VLOG(3) << "Neuron Adapter version: " << version;
......@@ -46,9 +50,6 @@ bool SubgraphEngine::BuildDeviceProgram() {
// Convert all of ops and their input vars and weights and added into the APU
// NIR graph
if (!origin_program_) {
BuildOriginProgram();
}
const auto& bridges = subgraph::Registry::Instance();
const auto& insts = origin_program_->instructions(kRootBlockIdx);
for (auto& inst : insts) {
......
......@@ -40,6 +40,11 @@ void CastCompute::Run() {
const auto* x_data = param.X->data<float>();
auto* o_data = param.Out->mutable_data<float>();
memcpy(o_data, x_data, sizeof(float) * param.X->numel());
} else if (param.in_dtype == param.out_dtype &&
param.in_dtype == 3) { // int64->int64
const auto* x_data = param.X->data<int64_t>();
auto* o_data = param.Out->mutable_data<int64_t>();
memcpy(o_data, x_data, sizeof(int64_t) * param.X->numel());
} else if (param.in_dtype == 21 && param.out_dtype == 5) { // int8->float32
const char* x_data_begin = param.X->data<char>();
const char* x_data_end = x_data_begin + param.X->numel();
......@@ -56,7 +61,7 @@ void CastCompute::Run() {
float* out_data = param.Out->mutable_data<float>();
std::transform(
x_data_begin, x_data_end, out_data, TransOp<unsigned char, float>);
} else if (param.in_dtype == 3 && param.out_dtype == 2) {
} else if (param.in_dtype == 3 && param.out_dtype == 2) { // int64->int32
const int64_t* x_data_begin = param.X->data<int64_t>();
const int64_t* x_data_end = x_data_begin + param.X->numel();
int32_t* out_data = param.Out->mutable_data<int32_t>();
......@@ -72,6 +77,12 @@ void CastCompute::Run() {
const int64_t* x_data_end = x_data_begin + param.X->numel();
float* out_data = param.Out->mutable_data<float>();
std::transform(x_data_begin, x_data_end, out_data, TransOp<int64_t, float>);
} else if (param.in_dtype == 2 && param.out_dtype == 3) { // INT32 -> INT64
const int32_t* x_data_begin = param.X->data<int32_t>();
const int32_t* x_data_end = x_data_begin + param.X->numel();
int64_t* out_data = param.Out->mutable_data<int64_t>();
std::transform(
x_data_begin, x_data_end, out_data, TransOp<int32_t, int64_t>);
} else {
LOG(FATAL) << "other has not been implemented transform with dtype"
<< param.in_dtype << " X, dtype" << param.out_dtype << " Out";
......
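For reference, the integer dtype codes compared in CastCompute above are not named in this file; a minimal sketch of the assumed mapping, following Paddle's VarType numbering and the inline comments in the hunk (illustration only):

// Assumed dtype codes used by CastCompute (sketch, not part of this commit):
enum AssumedVarTypeCode {
  kInt32Code = 2,    // "int32"
  kInt64Code = 3,    // "int64"
  kFloat32Code = 5,  // "float32"
  kUInt8Code = 20,   // "uint8"
  kInt8Code = 21,    // "int8"
};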
......@@ -169,21 +169,47 @@ void SliceCompute<T, PType>::Run() {
using slice_float =
paddle::lite::kernels::arm::SliceCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(slice, kARM, kFloat, kNCHW, slice_float, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
using slice_int32 =
paddle::lite::kernels::arm::SliceCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(slice, kARM, kInt32, kNCHW, slice_int32, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize();
using slice_int64 =
paddle::lite::kernels::arm::SliceCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(slice, kARM, kInt64, kNCHW, slice_int64, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("EndsTensorList",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
......@@ -21,9 +21,10 @@ namespace lite {
namespace kernels {
namespace arm {
void SplitCompute::Run() {
auto& param = Param<operators::SplitParam>();
const float* din = param.x->data<float>();
template <typename T, PrecisionType PType>
void SplitCompute<T, PType>::Run() {
auto& param = this->template Param<operators::SplitParam>();
const T* din = param.x->template data<T>();
auto& dout = param.output;
auto in_dim = param.x->dims();
std::vector<int> in_strides(in_dim.size());
......@@ -42,12 +43,24 @@ void SplitCompute::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
split, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SplitCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
using split_float =
paddle::lite::kernels::arm::SplitCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(split, kARM, kFloat, kNCHW, split_float, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
using split_int64 =
paddle::lite::kernels::arm::SplitCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(split, kARM, kInt64, kNCHW, split_int64, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
......@@ -22,7 +22,8 @@ namespace lite {
namespace kernels {
namespace arm {
class SplitCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
template <typename T, PrecisionType PType>
class SplitCompute : public KernelLite<TARGET(kARM), PType> {
public:
void Run() override;
......
......@@ -93,13 +93,13 @@ void split_compute_ref(const operators::SplitParam& param) {
}
TEST(split_arm, init) {
SplitCompute split;
SplitCompute<float, PRECISION(kFloat)> split;
ASSERT_EQ(split.precision(), PRECISION(kFloat));
ASSERT_EQ(split.target(), TARGET(kARM));
}
TEST(split_arm, compute) {
SplitCompute split;
SplitCompute<float, PRECISION(kFloat)> split;
operators::SplitParam param;
lite::Tensor x;
......
......@@ -14,6 +14,7 @@
#include "lite/kernels/cuda/gru_compute.h"
#include <string>
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/bias.h"
......@@ -273,6 +274,8 @@ void GRUCompute<T, PType>::Run() {
auto& param = this->template Param<param_t>();
auto* input = param.input;
T* x_data =
const_cast<lite::Tensor*>(input)->template mutable_data<T>(TARGET(kCUDA));
lite::Tensor* h0{nullptr};
if (param.h0) {
h0 = const_cast<lite::Tensor*>(param.h0);
......@@ -289,7 +292,7 @@ void GRUCompute<T, PType>::Run() {
lite::Tensor* hidden = param.hidden;
T* batch_reset_hidden_prev_data =
batch_reset_hidden_prev->template mutable_data<T>(TARGET(kCUDA));
hidden->template mutable_data<T>(TARGET(kCUDA));
T* out_data = hidden->template mutable_data<T>(TARGET(kCUDA));
T* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
T* batch_hidden_data = batch_hidden->template mutable_data<T>(TARGET(kCUDA));
bool is_reverse = param.is_reverse;
......@@ -300,14 +303,28 @@ void GRUCompute<T, PType>::Run() {
auto hidden_dims = hidden->dims();
int frame_size = hidden_dims[1];
lite::cuda::math::LoDTensor2BatchFunctor<T> batch_func;
batch_func(*input, batch_gate, is_reverse, stream);
LoD offset_vec_vec = input->lod();
std::vector<int> offset(offset_vec_vec[offset_vec_vec.size() - 1].size());
for (size_t i = 0; i < offset_vec_vec[offset_vec_vec.size() - 1].size();
++i) {
offset[i] = static_cast<int>(offset_vec_vec[offset_vec_vec.size() - 1][i]);
}
bool need_process = seq_utils_.GetSortedMap(offset, stream);
int emit_length = seq_utils_.GetEmitOffsetVec().size() - 1;
auto emit_offset_vec = seq_utils_.GetEmitOffsetVec();
if (need_process) {
seq_utils_.Seq2SortedSeq(
input->template data<T>(), batch_gate_data, 3 * frame_size, stream);
x_data = batch_gate_data;
out_data = batch_hidden_data;
}
if (bias) {
// TODO(wilber): validate when bias is not nullptr
lite::cuda::math::RowwiseAdd<T> add_bias;
add_bias(batch_gate_data,
add_bias(x_data,
bias->template data<T>(),
batch_gate_data,
x_data,
frame_size,
batch_gate->numel(),
stream);
......@@ -320,6 +337,7 @@ void GRUCompute<T, PType>::Run() {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
// TODO(wilber): validate when h0 is not nullptr
ordered_h0_.Resize(h0->dims());
lite::cuda::math::CopyMatrixRowsFunctor<T> row_shuffle;
row_shuffle(*h0, &ordered_h0_, batch_gate->lod()[2], true, stream);
......@@ -327,15 +345,13 @@ void GRUCompute<T, PType>::Run() {
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
for (size_t n = 0; n < num_batch; ++n) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
for (size_t n = 0; n < emit_length; ++n) {
int bstart = emit_offset_vec[n];
int bend = emit_offset_vec[n + 1];
int cur_batch_size = bend - bstart;
gru_value.output_value = batch_hidden_data + bstart * frame_size;
gru_value.gate_value = batch_gate_data + bstart * frame_size * 3;
gru_value.output_value = out_data + bstart * frame_size;
gru_value.gate_value = x_data + bstart * frame_size * 3;
gru_value.reset_output_value =
batch_reset_hidden_prev_data + bstart * frame_size;
......@@ -349,10 +365,13 @@ void GRUCompute<T, PType>::Run() {
&context);
gru_value.prev_out_value = gru_value.output_value;
}
lite::cuda::math::Batch2LoDTensorFunctor<T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(*batch_hidden, hidden, stream);
if (need_process) {
seq_utils_.SortedSeq2Seq(batch_hidden_data,
hidden->mutable_data<T>(TARGET(kCUDA)),
frame_size,
stream);
}
hidden->set_lod(input->lod());
}
} // namespace cuda
......
......@@ -16,6 +16,7 @@
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/backends/cuda/math/sequence_helper.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
......@@ -38,6 +39,7 @@ class GRUCompute : public KernelLite<TARGET(kCUDA), PType> {
private:
std::unique_ptr<lite::cuda::math::Gemm<T, T>> gemm_impl_{nullptr};
lite::Tensor ordered_h0_;
lite::cuda::math::SeqSortedseqTranseUtil seq_utils_;
};
} // namespace cuda
......
......@@ -11,6 +11,9 @@ lite_cc_library(subgraph_bridge_act_op_huawei_ascend_npu SRCS act_op.cc DEPS ${h
lite_cc_library(subgraph_bridge_conv_op_huawei_ascend_npu SRCS conv_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_interpolate_op_huawei_ascend_npu SRCS interpolate_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_concat_op_huawei_ascend_npu SRCS concat_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_pool_op_huawei_ascend_npu SRCS pool_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_elementwise_ops_huawei_ascend_npu SRCS elementwise_ops.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_batch_norm_op_huawei_ascend_npu SRCS batch_norm_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_registry
......@@ -20,4 +23,7 @@ set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_conv_op_huawei_ascend_npu
subgraph_bridge_interpolate_op_huawei_ascend_npu
subgraph_bridge_concat_op_huawei_ascend_npu
subgraph_bridge_pool_op_huawei_ascend_npu
subgraph_bridge_elementwise_ops_huawei_ascend_npu
subgraph_bridge_batch_norm_op_huawei_ascend_npu
CACHE INTERNAL "huawei_ascend_npu_subgraph_bridges")
......@@ -49,10 +49,8 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto act_node = graph->template Add<ActType>(out_name);
auto act_op = act_node->template data<ActType>();
act_op->set_input_x(*x_node->data());
TENSOR_UPDATE_INPUT(
act_op, x, ge::FORMAT_NCHW, CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_OUTPUT(
act_op, y, ge::FORMAT_NCHW, CvtPrecisionType(act_node->precision()));
INPUT_UPDATE(act_op, x, x_node);
OUTPUT_UPDATE(act_op, y, act_node);
return SUCCESS;
}
......@@ -88,10 +86,8 @@ int ActConverter<ge::op::LeakyRelu>(void* ctx, OpLite* op, KernelBase* kernel) {
// only for leaky_relu
auto alpha = op_info->GetAttr<float>("alpha");
act_op->set_attr_negative_slope(alpha);
TENSOR_UPDATE_INPUT(
act_op, x, ge::FORMAT_NCHW, CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_OUTPUT(
act_op, y, ge::FORMAT_NCHW, CvtPrecisionType(act_node->precision()));
INPUT_UPDATE(act_op, x, x_node);
OUTPUT_UPDATE(act_op, y, act_node);
return SUCCESS;
}
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input data nodes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto scale_name = op_info->Input("Scale").front();
auto scale = scope->FindMutableTensor(scale_name);
auto bias_name = op_info->Input("Bias").front();
auto bias = scope->FindMutableTensor(bias_name);
auto mean_name = op_info->Input("Mean").front();
auto mean = scope->FindMutableTensor(mean_name);
auto variance_name = op_info->Input("Variance").front();
auto variance = scope->FindMutableTensor(variance_name);
// Get output var nodes
auto y_name = op_info->Output("Y").front();
// Get attributes
float epsilon = op_info->GetAttr<float>("epsilon");
// Check is_test
auto is_test_type = op_info->GetAttrType("is_test");
if (is_test_type == OpDescAPI::AttrType::INT) {
CHECK_EQ(op_info->GetAttr<int>("is_test"), 1)
<< "[HUAWEI_ASCEND_NPU] Only is_test=1 or is_test=true is supported in "
"inference mode.";
} else if (is_test_type == OpDescAPI::AttrType::BOOLEAN) {
CHECK_EQ(op_info->GetAttr<bool>("is_test"), true)
<< "[HUAWEI_ASCEND_NPU] Only is_test=1 or is_test=true is supported in "
"inference mode.";
}
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Scale, Bias, Mean, Variance node
auto scale_node = graph->Add(scale_name, *scale);
auto bias_node = graph->Add(bias_name, *bias);
auto mean_node = graph->Add(mean_name, *mean);
auto variance_node = graph->Add(variance_name, *variance);
// Batch Norm node - output nodes
auto batch_norm_node = graph->Add<ge::op::BatchNorm>(y_name + "/batch_norm");
auto batch_norm_op = batch_norm_node->data<ge::op::BatchNorm>();
batch_norm_op->set_input_x(*x_node->data());
batch_norm_op->set_input_scale(*scale_node->data());
batch_norm_op->set_input_offset(*bias_node->data());
batch_norm_op->set_input_mean(*mean_node->data());
batch_norm_op->set_input_variance(*variance_node->data());
batch_norm_op->set_attr_epsilon(epsilon);
batch_norm_op->set_attr_data_format("NCHW");
batch_norm_op->set_attr_is_training(false);
INPUT_UPDATE(batch_norm_op, x, x_node);
INPUT_UPDATE(batch_norm_op, scale, scale_node);
INPUT_UPDATE(batch_norm_op, offset, bias_node);
INPUT_UPDATE(batch_norm_op, mean, mean_node);
INPUT_UPDATE(batch_norm_op, variance, variance_node);
OUTPUT_UPDATE(batch_norm_op, y, batch_norm_node);
// Create Variable node for batch norm output y
auto out_y_node = graph->Add<ge::op::Identity>(y_name);
auto out_y_op = out_y_node->data<ge::op::Identity>();
out_y_op->set_input_x(*batch_norm_node->data(), "y");
INPUT_UPDATE(out_y_op, x, batch_norm_node);
OUTPUT_UPDATE(out_y_op, y, out_y_node);
return SUCCESS;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
batch_norm,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::BatchNormConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -51,10 +51,8 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto concat_op = concat_node->data<ge::op::Concat>();
// set axis input
concat_op->set_input_concat_dim(*axis_node->data());
TENSOR_UPDATE_INPUT(concat_op,
concat_dim,
ge::FORMAT_NCHW,
CvtPrecisionType(axis_node->precision()));
INPUT_UPDATE(concat_op, concat_dim, axis_node);
// set dynamic input
concat_op->set_attr_N(num);
concat_op->create_dynamic_input_x(num);
......@@ -69,17 +67,10 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
x_node = graph->Add(x_name, *x);
}
concat_op->set_dynamic_input_x(idx, *x_node->data());
TENSOR_UPDATE_DYNAMIC_INPUT(concat_op,
x,
idx,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
DYNAMIC_INPUT_UPDATE(concat_op, x, idx, x_node);
idx++;
}
TENSOR_UPDATE_OUTPUT(concat_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(concat_node->precision()));
OUTPUT_UPDATE(concat_op, y, concat_node);
} else {
auto concat_node = graph->Add<ge::op::ConcatD>(out_name);
auto concat_op = concat_node->data<ge::op::ConcatD>();
......@@ -97,17 +88,10 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
x_node = graph->Add(x_name, *x);
}
concat_op->set_dynamic_input_x(idx, *x_node->data());
TENSOR_UPDATE_DYNAMIC_INPUT(concat_op,
x,
idx,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
DYNAMIC_INPUT_UPDATE(concat_op, x, idx, x_node);
idx++;
}
TENSOR_UPDATE_OUTPUT(concat_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(concat_node->precision()));
OUTPUT_UPDATE(concat_op, y, concat_node);
}
return SUCCESS;
......
......@@ -182,19 +182,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
conv_op->set_attr_data_format("NCHW");
if (bias_node != nullptr && is_channel_bias) {
conv_op->set_input_bias(*bias_node->data());
TENSOR_UPDATE_INPUT(conv_op,
bias,
ge::FORMAT_NCHW,
CvtPrecisionType(bias_node->precision()));
INPUT_UPDATE(conv_op, bias, bias_node);
}
TENSOR_UPDATE_INPUT(
conv_op, x, ge::FORMAT_NCHW, CvtPrecisionType(input_node->precision()));
TENSOR_UPDATE_INPUT(conv_op,
filter,
ge::FORMAT_NCHW,
CvtPrecisionType(filter_node->precision()));
TENSOR_UPDATE_OUTPUT(
conv_op, y, ge::FORMAT_NCHW, CvtPrecisionType(conv_node->precision()));
INPUT_UPDATE(conv_op, x, input_node);
INPUT_UPDATE(conv_op, filter, filter_node);
OUTPUT_UPDATE(conv_op, y, conv_node);
} else {
conv_node = graph->Add<ge::op::Conv2D>(output_name);
auto conv_op = conv_node->data<ge::op::Conv2D>();
......@@ -210,19 +202,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
conv_op->set_attr_data_format("NCHW");
if (bias_node != nullptr && is_channel_bias) {
conv_op->set_input_bias(*bias_node->data());
TENSOR_UPDATE_INPUT(conv_op,
bias,
ge::FORMAT_NCHW,
CvtPrecisionType(bias_node->precision()));
INPUT_UPDATE(conv_op, bias, bias_node);
}
TENSOR_UPDATE_INPUT(
conv_op, x, ge::FORMAT_NCHW, CvtPrecisionType(input_node->precision()));
TENSOR_UPDATE_INPUT(conv_op,
filter,
ge::FORMAT_NCHW,
CvtPrecisionType(filter_node->precision()));
TENSOR_UPDATE_OUTPUT(
conv_op, y, ge::FORMAT_NCHW, CvtPrecisionType(conv_node->precision()));
INPUT_UPDATE(conv_op, x, input_node);
INPUT_UPDATE(conv_op, filter, filter_node);
OUTPUT_UPDATE(conv_op, y, conv_node);
}
// append Add node to support bias
if (bias_node != nullptr && !is_channel_bias) {
......@@ -230,7 +214,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto add_op = add_node->data<ge::op::Add>();
add_op->set_input_x1(*conv_node->data());
add_op->set_input_x2(*bias_node->data());
conv_node = add_node;
INPUT_UPDATE(add_op, x1, conv_node);
INPUT_UPDATE(add_op, x2, bias_node);
OUTPUT_UPDATE(add_op, y, add_node);
}
CHECK(conv_node);
......@@ -241,11 +227,15 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto act_node = graph->Add<ge::op::Relu>(output_name);
auto act_op = act_node->data<ge::op::Relu>();
act_op->set_input_x(*conv_node->data());
INPUT_UPDATE(act_op, x, conv_node);
OUTPUT_UPDATE(act_op, y, act_node);
} else if (act_type == "leaky_relu") {
auto act_node = graph->Add<ge::op::LeakyRelu>(output_name);
auto act_op = act_node->data<ge::op::LeakyRelu>();
act_op->set_input_x(*conv_node->data());
act_op->set_attr_negative_slope(leaky_relu_alpha);
INPUT_UPDATE(act_op, x, conv_node);
OUTPUT_UPDATE(act_op, y, act_node);
} else {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] act type not supported: "
<< act_type;
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
void CvtXYShape(std::vector<int64_t>* x_shape,
std::vector<int64_t>* y_shape,
int axis) {
int x_shape_size = x_shape->size();
int y_shape_size = y_shape->size();
CHECK_GE(x_shape_size, y_shape_size);
// only support:
// 1. same shape
// 2. (n,c,h,w) * (1,c,1,1)
// 3. (n,c,h,w) * (n,c,1,1)
// 4. (n,c,h,w) * (1,c,h,1)
// 5. (n,c,h,w) * (1,c,h,w)
// 6. (n,c,h,w) * (n,c,1,w)
if (*x_shape == *y_shape) {
*x_shape = CvtShape(*x_shape);
*y_shape = CvtShape(*y_shape);
return;
}
if (y_shape_size == 1) {
for (int i = 0; i < 4 - x_shape_size; i++) {
x_shape->push_back(1);
}
int64_t n = x_shape->at(0);
int64_t c = x_shape->at(1);
int64_t h = x_shape->at(2);
int64_t w = x_shape->at(3);
if (axis == 0) {
*x_shape = std::vector<int64_t>{1, n, c * h * w, 1};
} else if (axis == 2) {
*x_shape = std::vector<int64_t>{n * c, h, w, 1};
} else if (axis == 3) {
*x_shape = std::vector<int64_t>{n * c * h, w, 1, 1};
}
*y_shape = std::vector<int64_t>{1, y_shape->at(0), 1, 1};
return;
}
if (y_shape_size == 2) {
for (int i = 0; i < 4 - x_shape_size; i++) {
x_shape->push_back(1);
}
int64_t n = x_shape->at(0);
int64_t c = x_shape->at(1);
int64_t h = x_shape->at(2);
int64_t w = x_shape->at(3);
if (axis == 0) {
y_shape->insert(y_shape->end(), 2, 1);
} else if (axis == 1) {
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 1);
} else if (axis == 2) {
*x_shape = std::vector<int64_t>{n * c, h, w, 1};
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 1);
}
return;
}
if (y_shape_size == 3) {
y_shape->insert(y_shape->begin(), 1);
int64_t n = x_shape->at(0);
int64_t c = x_shape->at(1);
int64_t h = x_shape->at(2);
int64_t w = x_shape->at(3);
if (axis == 0) {
*x_shape = std::vector<int64_t>{1, n * c * h, w, 1};
*y_shape = std::vector<int64_t>{1, n * c * h, 1, 1};
}
return;
}
}
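A concrete case may clarify the reshaping above (illustration only):

// Example: x_shape = {2, 3, 4, 5}, y_shape = {3}, axis = 1.
//   y_shape_size == 1 and axis == 1, so x_shape stays {2, 3, 4, 5} and
//   y_shape becomes {1, 3, 1, 1} -- the supported (n,c,h,w) * (1,c,1,1) case.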
int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Input("Y").front();
auto y = scope->FindTensor(y_name);
auto y_dims = y->dims();
auto out_name = op_info->Output("Out").front();
auto out = scope->FindTensor(out_name);
auto out_dims = out->dims();
auto axis = op_info->GetAttr<int>("axis");
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
auto x_new_shape = x_dims.Vectorize();
auto y_new_shape = y_dims.Vectorize();
CvtXYShape(&x_new_shape, &y_new_shape, axis);
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
auto shape_node = graph->Add<int64_t>(x_name + "/shape", x_new_shape);
auto reshaped_x_node = graph->Add<ge::op::Reshape>(x_name + "/reshape");
auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
reshaped_x_op->set_input_x(*x_node->data());
reshaped_x_op->set_input_shape(*shape_node->data());
reshaped_x_op->set_attr_axis(0);
INPUT_UPDATE(reshaped_x_op, x, x_node);
INPUT_UPDATE(reshaped_x_op, shape, shape_node);
OUTPUT_UPDATE(reshaped_x_op, y, reshaped_x_node);
x_node = reshaped_x_node;
} else {
x_node = graph->Add(x_name, *x, x_new_shape);
}
// Y node
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
auto shape_node = graph->Add<int64_t>(y_name + "/shape", y_new_shape);
auto reshaped_y_node = graph->Add<ge::op::Reshape>(y_name + "/reshape");
auto reshaped_y_op = reshaped_y_node->data<ge::op::Reshape>();
reshaped_y_op->set_input_x(*y_node->data());
reshaped_y_op->set_input_shape(*shape_node->data());
reshaped_y_op->set_attr_axis(0);
INPUT_UPDATE(reshaped_y_op, x, y_node);
INPUT_UPDATE(reshaped_y_op, shape, shape_node);
OUTPUT_UPDATE(reshaped_y_op, y, reshaped_y_node);
y_node = reshaped_y_node;
} else {
y_node = graph->Add(y_name, *y, y_new_shape);
}
// Elementwise node
std::shared_ptr<Node> elt_node = nullptr;
if (op_type == "elementwise_add" ||
op_type == "fusion_elementwise_add_activation") {
elt_node = graph->Add<ge::op::Add>(out_name);
auto elt_op = elt_node->data<ge::op::Add>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node);
} else if (op_type == "elementwise_sub" ||
op_type == "fusion_elementwise_sub_activation") {
elt_node = graph->Add<ge::op::Sub>(out_name);
auto elt_op = elt_node->data<ge::op::Sub>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node);
} else if (op_type == "elementwise_mul" ||
op_type == "fusion_elementwise_mul_activation") {
elt_node = graph->Add<ge::op::Mul>(out_name);
auto elt_op = elt_node->data<ge::op::Mul>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node);
} else if (op_type == "elementwise_div" ||
op_type == "fusion_elementwise_div_activation") {
elt_node = graph->Add<ge::op::RealDiv>(out_name);
auto elt_op = elt_node->data<ge::op::RealDiv>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
INPUT_UPDATE(elt_op, x1, x_node);
INPUT_UPDATE(elt_op, x2, y_node);
OUTPUT_UPDATE(elt_op, y, elt_node);
} else {
LOG(WARNING) << "[NPU] Unsupported op type: " << op_type;
return FAILED;
}
auto out_shape = out_dims.Vectorize();
if (out_shape != x_new_shape) {
auto shape_node = graph->Add<int64_t>(out_name + "/shape", out_shape);
auto reshaped_elt_node = graph->Add<ge::op::Reshape>(out_name);
auto reshaped_elt_op = reshaped_elt_node->data<ge::op::Reshape>();
reshaped_elt_op->set_input_x(*elt_node->data());
reshaped_elt_op->set_input_shape(*shape_node->data());
reshaped_elt_op->set_attr_axis(0);
INPUT_UPDATE(reshaped_elt_op, x, elt_node);
INPUT_UPDATE(reshaped_elt_op, shape, shape_node);
OUTPUT_UPDATE(reshaped_elt_op, y, reshaped_elt_node);
elt_node = reshaped_elt_node;
}
// Act node
if (op_type == "fusion_elementwise_add_activation" ||
op_type == "fusion_elementwise_sub_activation" ||
op_type == "fusion_elementwise_mul_activation" ||
op_type == "fusion_elementwise_div_activation") {
auto act_type = op_info->GetAttr<std::string>("act_type");
if (act_type == "leaky_relu") {
auto act_node = graph->Add<ge::op::LeakyRelu>(out_name);
auto act_op = act_node->data<ge::op::LeakyRelu>();
act_op->set_input_x(*elt_node->data());
auto alpha = op_info->GetAttr<float>("alpha");
act_op->set_attr_negative_slope(alpha);
INPUT_UPDATE(act_op, x, elt_node);
OUTPUT_UPDATE(act_op, y, act_node);
} else if (act_type == "relu") {
auto act_node = graph->Add<ge::op::Relu>(out_name);
auto act_op = act_node->data<ge::op::Relu>();
act_op->set_input_x(*elt_node->data());
INPUT_UPDATE(act_op, x, elt_node);
OUTPUT_UPDATE(act_op, y, act_node);
} else {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Unsupported act type: " << act_type;
return FAILED;
}
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
elementwise_add,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
elementwise_sub,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
elementwise_mul,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
elementwise_div,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_add_activation,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_sub_activation,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_mul_activation,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(
fusion_elementwise_div_activation,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ElementwiseConverter);
......@@ -37,19 +37,29 @@ class Node {
kData,
};
Node(std::shared_ptr<ge::Operator> data,
Node(std::string name,
std::shared_ptr<ge::Operator> data,
PrecisionType precision,
DataLayoutType layout,
Role role)
: data_(data), precision_(precision), layout_(layout), role_(role) {}
Node(PrecisionType precision, DataLayoutType layout, Role role)
: precision_(precision), layout_(layout), role_(role) {}
: name_(name),
data_(data),
precision_(precision),
layout_(layout),
role_(role) {}
Node(std::string name,
PrecisionType precision,
DataLayoutType layout,
Role role)
: name_(name), precision_(precision), layout_(layout), role_(role) {}
void set_name(std::string name) { name_ = name; }
void set_data(std::shared_ptr<ge::Operator> data) { data_ = data; }
void set_precision(PrecisionType precision) { precision_ = precision; }
void set_layout(DataLayoutType layout) { layout_ = layout; }
void set_role(Role role) { role_ = role; }
std::string name() { return name_; }
template <typename T>
std::shared_ptr<T> data() {
return std::static_pointer_cast<T>(data_);
......@@ -62,6 +72,7 @@ class Node {
bool is_data() const { return role_ == Role::kData; }
private:
std::string name_{};
std::shared_ptr<ge::Operator> data_{nullptr};
PrecisionType precision_{PRECISION(kFloat)};
DataLayoutType layout_{DATALAYOUT(kNCHW)};
......@@ -83,10 +94,10 @@ class Graph {
} else if (typeid(T) == typeid(ge::op::Data)) {
role = Node::Role::kData;
}
auto node = std::make_shared<Node>(precision, layout, role);
auto node = std::make_shared<Node>(name, precision, layout, role);
auto idx = Add(name, node);
CHECK_GE(idx, 1);
// Generate a unique name for the created HiAI IR
// Generate a unique name for the created Huawei Ascend NPU IR
node->set_data(
std::make_shared<T>(name + "__" + paddle::lite::to_string(idx)));
return node;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -97,18 +97,9 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
bilinear_interp_op->set_input_x(*x_node->data());
bilinear_interp_op->set_input_size(*out_size_node->data());
bilinear_interp_op->set_attr_align_corners(align_corners);
TENSOR_UPDATE_INPUT(bilinear_interp_op,
x,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_INPUT(bilinear_interp_op,
size,
ge::FORMAT_NCHW,
CvtPrecisionType(out_size_node->precision()));
TENSOR_UPDATE_OUTPUT(bilinear_interp_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(bilinear_interp_node->precision()));
INPUT_UPDATE(bilinear_interp_op, x, x_node);
INPUT_UPDATE(bilinear_interp_op, size, out_size_node);
OUTPUT_UPDATE(bilinear_interp_op, y, bilinear_interp_node);
} else if (interp_method == "nearest") {
auto nearest_interp_node =
graph->Add<ge::op::ResizeNearestNeighborV2>(out_name);
......@@ -117,18 +108,9 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
nearest_interp_op->set_input_x(*x_node->data());
nearest_interp_op->set_input_size(*out_size_node->data());
nearest_interp_op->set_attr_align_corners(align_corners);
TENSOR_UPDATE_INPUT(nearest_interp_op,
x,
ge::FORMAT_NCHW,
CvtPrecisionType(x_node->precision()));
TENSOR_UPDATE_INPUT(nearest_interp_op,
size,
ge::FORMAT_NCHW,
CvtPrecisionType(out_size_node->precision()));
TENSOR_UPDATE_OUTPUT(nearest_interp_op,
y,
ge::FORMAT_NCHW,
CvtPrecisionType(nearest_interp_node->precision()));
INPUT_UPDATE(nearest_interp_op, x, x_node);
INPUT_UPDATE(nearest_interp_op, size, out_size_node);
OUTPUT_UPDATE(nearest_interp_op, y, nearest_interp_node);
} else {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Unsupported interpolate method: "
<< interp_method;
......
......@@ -22,9 +22,18 @@ USE_SUBGRAPH_BRIDGE(relu6, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(leaky_relu, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(softsign, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(softplus, kHuaweiAscendNPU);
// conv
USE_SUBGRAPH_BRIDGE(conv2d, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(depthwise_conv2d, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(bilinear_interp, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(nearest_interp, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(concat, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(pool2d, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_add, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_sub, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_mul, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(elementwise_div, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_sub_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_mul_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_div_activation, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(batch_norm, kHuaweiAscendNPU);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/pool_op.h"
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto pooling_type = op_info->GetAttr<std::string>("pooling_type");
auto global_pooling = op_info->GetAttr<bool>("global_pooling");
auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
CHECK_EQ(op_info->GetAttr<bool>("exclusive"), true)
<< "[HUAWEI_ASCEND_NPU] Only exclusive=true is supported for Huawei "
"Ascend NPU DDK.";
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// pool mode: 0:max pooling or 1:avg pooling
int mode = 0;
if (pooling_type == "max") {
mode = 0;
} else if (pooling_type == "avg") {
mode = 1;
} else {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Unsupported pooling type: "
<< pooling_type;
return FAILED;
}
// pad algorithm
std::string padding_algorithm("");
if (op_info->HasAttr("padding_algorithm")) {
padding_algorithm = op_info->GetAttr<std::string>("padding_algorithm");
}
// paddings and strides
if (paddings.size() == 2L) {
for (size_t i = 0; i < 2L; ++i) {
int copy_pad = *(paddings.begin() + 2 * i);
paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
}
}
CHECK_EQ(paddings.size(), 4L) << "[HUAWEI_ASCEND_NPU] Paddings size should "
"be the same as or twice the inputs size.";
bool adaptive = false;
if (op_info->HasAttr("adaptive")) {
adaptive = op_info->GetAttr<bool>("adaptive");
}
auto strides = op_info->GetAttr<std::vector<int>>("strides");
lite::operators::UpdatePadding(&paddings,
global_pooling,
adaptive,
padding_algorithm,
x->dims(),
strides,
ksize);
// Ascend restriction: padT should equal padB, and padL should equal padR
CHECK_EQ(paddings[0], paddings[1]) << "[HUAWEI_ASCEND_NPU] Padding top "
"should equal padding bottom in "
"Huawei Ascend NPU DDK";
CHECK_EQ(paddings[2], paddings[3]) << "[HUAWEI_ASCEND_NPU] Padding left "
"should equal padding right in "
"Huawei Ascend NPU DDK";
// ceil mode
bool ceil_mode =
op_info->HasAttr("ceil_mode") && op_info->GetAttr<bool>("ceil_mode");
// Pooling node
auto pool_node = graph->Add<ge::op::Pooling>(out_name);
auto pool_op = pool_node->data<ge::op::Pooling>();
pool_op->set_input_x(*x_node->data());
pool_op->set_attr_mode(mode);
pool_op->set_attr_global_pooling(global_pooling);
pool_op->set_attr_window(ge::Operator::OpListInt({ksize[0], ksize[1]}));
pool_op->set_attr_stride(ge::Operator::OpListInt({strides[0], strides[1]}));
pool_op->set_attr_pad(ge::Operator::OpListInt(
{paddings[0], paddings[1], paddings[2], paddings[3]}));
// "0" (ceil mode) or "1" (floor mode). Defaults to "0"
if (!ceil_mode) {
pool_op->set_attr_ceil_mode(1);
}
INPUT_UPDATE(pool_op, x, x_node);
OUTPUT_UPDATE(pool_op, y, pool_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
pool2d,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::PoolConverter);
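A minimal standalone sketch (not part of this commit) of the padding handling in PoolConverter above, using hypothetical values: a 2-element paddings attribute is expanded in place to 4 elements before the symmetry checks.
#include <cassert>
#include <vector>
int main() {
  // Hypothetical pool attribute: {pad_h, pad_w}.
  std::vector<int> paddings = {1, 2};
  // Same expansion as in PoolConverter: duplicate each entry in place,
  // producing {pad_top, pad_bottom, pad_left, pad_right}.
  if (paddings.size() == 2L) {
    for (size_t i = 0; i < 2L; ++i) {
      int copy_pad = *(paddings.begin() + 2 * i);
      paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
    }
  }
  assert((paddings == std::vector<int>{1, 1, 2, 2}));
  // The Ascend checks then require paddings[0] == paddings[1] and
  // paddings[2] == paddings[3], which holds by construction here.
  return 0;
}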
......@@ -19,9 +19,7 @@
#include <memory>
#include <string>
#include <vector>
// #include "graph/buffer.h"
#include "graph/tensor.h"
#include "graph/types.h"
#include "lite/backends/huawei_ascend_npu/utils.h"
#include "lite/core/op_lite.h"
#include "lite/utils/macros.h"
......@@ -30,16 +28,34 @@ namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
#define TENSOR_UPDATE_INPUT(op, attr, format, dtype) \
ge::TensorDesc _##op##_input_desc_##attr(ge::Shape(), format, dtype); \
#define INPUT_UPDATE(...) TENSOR_INPUT_UPDATE(__VA_ARGS__, ge::FORMAT_NCHW)
#define OUTPUT_UPDATE(...) TENSOR_OUTPUT_UPDATE(__VA_ARGS__, ge::FORMAT_NCHW)
#define DYNAMIC_INPUT_UPDATE(...) \
TENSOR_DYNAMIC_INPUT_UPDATE(__VA_ARGS__, ge::FORMAT_NCHW)
#define DYNAMIC_OUTPUT_UPDATE(...) \
TENSOR_DYNAMIC_OUTPUT_UPDATE(__VA_ARGS__, ge::FORMAT_NCHW)
#define TENSOR_INPUT_UPDATE(op, attr, node, format) \
ge::TensorDesc _##op##_input_desc_##attr( \
ge::Shape(), format, CvtPrecisionType(node->precision())); \
_##op##_input_desc_##attr.SetName(node->name()); \
op->update_input_desc_##attr(_##op##_input_desc_##attr);
#define TENSOR_UPDATE_OUTPUT(op, attr, format, dtype) \
ge::TensorDesc _##op##_output_desc_##attr(ge::Shape(), format, dtype); \
#define TENSOR_OUTPUT_UPDATE(op, attr, node, format) \
ge::TensorDesc _##op##_output_desc_##attr( \
ge::Shape(), format, CvtPrecisionType(node->precision())); \
_##op##_output_desc_##attr.SetName(node->name()); \
op->update_output_desc_##attr(_##op##_output_desc_##attr);
#define TENSOR_UPDATE_DYNAMIC_INPUT(op, attr, idx, format, dtype) \
ge::TensorDesc _##op##_input_desc_##attr##_##idx( \
ge::Shape(), format, dtype); \
#define TENSOR_DYNAMIC_INPUT_UPDATE(op, attr, idx, node, format) \
ge::TensorDesc _##op##_input_desc_##attr##_##idx( \
ge::Shape(), format, CvtPrecisionType(node->precision())); \
_##op##_input_desc_##attr##_##idx.SetName(node->name()); \
op->update_dynamic_input_desc_##attr(idx, _##op##_input_desc_##attr##_##idx);
#define TENSOR_DYNAMIC_OUTPUT_UPDATE(op, attr, idx, node, format) \
ge::TensorDesc _##op##_output_desc_##attr##_##idx( \
ge::Shape(), format, CvtPrecisionType(node->precision())); \
_##op##_output_desc_##attr##_##idx.SetName(node->name()); \
op->update_dynamic_output_desc_##attr(idx, \
_##op##_output_desc_##attr##_##idx);
// Type/tensor converters for converting Paddle type/tensor to HiAI type/tensor
bool HasInputArg(const OpInfo* op_info,
......
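As a rough illustration of the macros above (not from the commit), INPUT_UPDATE(pool_op, x, x_node), as used in PoolConverter where pool_op and x_node are in scope, expands approximately to:
// INPUT_UPDATE(pool_op, x, x_node)
//   -> TENSOR_INPUT_UPDATE(pool_op, x, x_node, ge::FORMAT_NCHW)
//   -> approximately:
ge::TensorDesc _pool_op_input_desc_x(
    ge::Shape(), ge::FORMAT_NCHW, CvtPrecisionType(x_node->precision()));
_pool_op_input_desc_x.SetName(x_node->name());
pool_op->update_input_desc_x(_pool_op_input_desc_x);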
......@@ -220,7 +220,7 @@ bool DeviceProgram::ShareBufferWithOriginTensors(
CHECK(!model_name_.empty() && model_client_);
// Query the dimensions of the device input and output tensors if not
// initialized
VLOG(3) << "[HUAWEI_ASCEND_NPU] Sharing buffer with origin tnsors...";
VLOG(3) << "[HUAWEI_ASCEND_NPU] Sharing buffer with origin tensors...";
if (device_idims_.empty() || device_odims_.empty()) {
if (!(model_client_->GetModelIOTensorDim(&device_idims_, &device_odims_))) {
LOG(WARNING)
......
......@@ -153,11 +153,19 @@ void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) {
LOG(FATAL) << "Unsupported attr type found " << static_cast<int>(type);
}
};
// On the ARM backend, some op attributes have no effect on the inference
// process, so we drop them to reduce model size and run-time memory usage.
// This is done in the opt tool, so it does not increase initialization time.
std::vector<std::string> skiped_attributes = {"op_callstack",
"op_namescope",
"op_role",
"workspace_size_MB",
"op_role_var"};
for (const auto &attr_name : any_desc.AttrNames()) {
// note: since `op_callstack` attribute has no effect on inference process,
// we will not load it into op_desc.
if (attr_name != "op_callstack") {
auto it = std::find(
skiped_attributes.begin(), skiped_attributes.end(), attr_name);
if (it == skiped_attributes.end()) {
auto type = any_desc.GetAttrType(attr_name);
set_attr(attr_name, type);
}
......
......@@ -10,5 +10,6 @@ lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS fbs_headers)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS fbs_headers)
lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc)
lite_fbs_library(fbs_param_desc SRCS param_desc.cc FBS_DEPS fbs_headers)
lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc fbs_param_desc)
lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc fbs_param_desc scope)
lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc)
lite_cc_test(test_fbs_io SRCS io_test.cc DEPS fbs_io)
......@@ -13,35 +13,62 @@
// limitations under the License.
#include "lite/model_parser/flatbuffers/block_desc.h"
#include <memory>
namespace paddle {
namespace lite {
namespace fbs {
template <>
proto::VarDesc const* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) const {
proto::VarDesc const* BlockDescView::GetVar<proto::VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return desc_->vars()->Get(idx);
}
template <>
proto::OpDesc const* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) const {
proto::OpDesc const* BlockDescView::GetOp<proto::OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return desc_->ops()->Get(idx);
}
template <>
VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
VarDescView const* BlockDescView::GetVar<VarDescView>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return &vars_[idx];
}
template <>
OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
OpDescView const* BlockDescView::GetOp<OpDescView>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return &ops_[idx];
}
template <>
proto::VarDescT* BlockDesc::GetVar<proto::VarDescT>(int32_t idx) {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return vars_[idx].raw_desc();
}
template <>
proto::VarDescT* BlockDesc::AddVar<proto::VarDescT>() {
desc_->vars.push_back(std::unique_ptr<proto::VarDescT>(new proto::VarDescT));
SyncVars();
return vars_.back().raw_desc();
}
template <>
proto::OpDescT* BlockDesc::GetOp<proto::OpDescT>(int32_t idx) {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return ops_[idx].raw_desc();
}
template <>
proto::OpDescT* BlockDesc::AddOp<proto::OpDescT>() {
desc_->ops.push_back(std::unique_ptr<proto::OpDescT>(new proto::OpDescT));
SyncOps();
return ops_.back().raw_desc();
}
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -25,17 +25,17 @@ namespace paddle {
namespace lite {
namespace fbs {
class BlockDesc : public BlockDescAPI {
class BlockDescView : public BlockDescAPI {
public:
explicit BlockDesc(proto::BlockDesc const* desc) : desc_(desc) {
explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) {
CHECK(desc_);
vars_.reserve(VarsSize());
ops_.reserve(OpsSize());
for (size_t idx = 0; idx < VarsSize(); ++idx) {
vars_.push_back(VarDesc(desc_->vars()->Get(idx)));
vars_.push_back(VarDescView(desc_->vars()->Get(idx)));
}
for (size_t idx = 0; idx < OpsSize(); ++idx) {
ops_.push_back(OpDesc(desc_->ops()->Get(idx)));
ops_.push_back(OpDescView(desc_->ops()->Get(idx)));
}
}
......@@ -69,26 +69,103 @@ class BlockDesc : public BlockDescAPI {
return nullptr;
}
const std::vector<VarDesc>& GetVars() const { return vars_; }
const std::vector<VarDescView>& GetVars() const { return vars_; }
int32_t ForwardBlockIdx() const override {
return desc_->forward_block_idx();
}
BlockDesc() { NotImplemented(); }
BlockDescView() { NotImplemented(); }
private:
proto::BlockDesc const* desc_; // not_own
std::vector<VarDesc> vars_;
std::vector<OpDesc> ops_;
std::vector<VarDescView> vars_;
std::vector<OpDescView> ops_;
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of BlockDesc is temporarily "
LOG(FATAL) << "The additional interfaces of BlockDescView is temporarily "
"unavailable in read-only mode.";
}
};
class BlockDesc : public BlockDescAPI {
public:
BlockDesc() : owned_(true), desc_(new proto::BlockDescT()) {}
explicit BlockDesc(proto::BlockDescT* desc) : desc_(desc) { CHECK(desc_); }
int32_t Idx() const override { return desc_->idx; }
void SetIdx(int32_t idx) override { desc_->idx = idx; }
int32_t ParentIdx() const override { return desc_->parent_idx; }
void SetParentIdx(int32_t idx) override { desc_->parent_idx = idx; }
size_t VarsSize() const override { return desc_->vars.size(); }
void ClearVars() override {
desc_->vars.clear();
SyncVars();
}
size_t OpsSize() const override { return desc_->ops.size(); }
void ClearOps() override {
desc_->ops.clear();
SyncOps();
}
int32_t ForwardBlockIdx() const override { return desc_->forward_block_idx; }
void SetForwardBlockIdx(int32_t idx_in) override {
desc_->forward_block_idx = idx_in;
}
proto::BlockDescT* raw_desc() { return desc_; }
template <typename T>
T* GetVar(int32_t idx);
template <typename T>
T* AddVar();
template <typename T>
T* GetOp(int32_t idx);
template <typename T>
T* AddOp();
~BlockDesc() {
if (owned_) {
delete desc_;
}
}
private:
void SyncVars() {
vars_.resize(desc_->vars.size());
for (size_t i = 0; i < desc_->vars.size(); ++i) {
if (vars_[i].raw_desc() != desc_->vars[i].get()) {
vars_[i] = VarDesc(desc_->vars[i].get());
}
}
}
void SyncOps() {
ops_.resize(desc_->ops.size());
for (size_t i = 0; i < desc_->ops.size(); ++i) {
if (ops_[i].raw_desc() != desc_->ops[i].get()) {
ops_[i] = OpDesc(desc_->ops[i].get());
}
}
}
bool owned_{false};
proto::BlockDescT* desc_{nullptr};
std::vector<VarDesc> vars_;
std::vector<OpDesc> ops_;
};
} // namespace fbs
} // namespace lite
} // namespace paddle
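A minimal usage sketch (not part of this commit) of the writable BlockDesc added above, assuming the flatbuffers object-API types (proto::BlockDescT, proto::VarDescT, proto::OpDescT) generated from framework.fbs:
#include "lite/model_parser/flatbuffers/block_desc.h"
void BlockDescSketch() {
  paddle::lite::fbs::BlockDesc block;  // owns a fresh proto::BlockDescT
  block.SetIdx(0);
  block.SetParentIdx(-1);
  // AddVar/AddOp append to the underlying vectors and re-sync the wrappers.
  auto* raw_var = block.AddVar<paddle::lite::fbs::proto::VarDescT>();
  auto* raw_op = block.AddOp<paddle::lite::fbs::proto::OpDescT>();
  // The raw descs can then be edited through the write wrappers, e.g.
  paddle::lite::fbs::VarDesc(raw_var).SetName("x");
  paddle::lite::fbs::OpDesc(raw_op).SetType("relu");
}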
......@@ -23,16 +23,22 @@ namespace paddle {
namespace lite {
namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog) {
CHECK(prog);
std::vector<char> LoadFile(const std::string& path) {
FILE* file = fopen(path.c_str(), "rb");
fseek(file, 0, SEEK_END);
int64_t length = ftell(file);
rewind(file);
std::vector<char> buf(length);
CHECK(fread(buf.data(), 1, length, file));
CHECK(fread(buf.data(), 1, length, file) == length);
fclose(file);
return buf;
}
void SaveFile(const std::string& path, const void* src, size_t byte_size) {
CHECK(src);
FILE* file = fopen(path.c_str(), "wb");
CHECK(fwrite(src, sizeof(char), byte_size, file) == byte_size);
fclose(file);
prog->Init(std::move(buf));
}
void SetParamWithTensor(const std::string& name,
......@@ -72,6 +78,7 @@ void SetScopeWithCombinedParams(lite::Scope* scope,
SetTensorWithParam(tensor, param);
}
}
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -25,18 +25,14 @@ namespace paddle {
namespace lite {
namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog);
void SetParamWithTensor(const std::string& name,
const lite::Tensor& tensor,
ParamDescWriteAPI* prog);
void SetTensorWithParam(const lite::Tensor& tensor, ParamDescReadAPI* prog);
std::vector<char> LoadFile(const std::string& path);
void SaveFile(const std::string& path, const void* src, size_t byte_size);
void SetScopeWithCombinedParams(lite::Scope* scope,
const CombinedParamsDescReadAPI& params);
void SetCombinedParamsWithScope(const lite::Scope& scope,
const std::vector<std::string>& params_name,
CombinedParamsDescWriteAPI* params);
void SetScopeWithCombinedParams(lite::Scope* scope,
const CombinedParamsDescReadAPI& params);
} // namespace fbs
} // namespace lite
......
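A small round-trip sketch (not from the commit) of the LoadFile/SaveFile helpers declared above; the file paths are hypothetical:
#include "lite/model_parser/flatbuffers/io.h"
void IoRoundTripSketch() {
  // Read an entire file into memory, then write the bytes back out unchanged.
  std::vector<char> buf = paddle::lite::fbs::LoadFile("model.fbs.bin");
  paddle::lite::fbs::SaveFile("model_copy.fbs.bin", buf.data(), buf.size());
}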
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/io.h"
#include <gtest/gtest.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace fbs {
namespace {
template <typename T>
void set_tensor(paddle::lite::Tensor* tensor,
const std::vector<int64_t>& dims) {
auto production =
std::accumulate(begin(dims), end(dims), 1, std::multiplies<int64_t>());
tensor->Resize(dims);
std::vector<T> data;
data.resize(production);
for (size_t i = 0; i < production; ++i) {
data[i] = i / 2.f;
}
std::memcpy(tensor->mutable_data<T>(), data.data(), sizeof(T) * data.size());
}
} // namespace
TEST(CombinedParamsDesc, Scope) {
/* --------- Save scope ---------- */
Scope scope;
std::vector<std::string> params_name({"var_0", "var_1"});
// variable 0
Variable* var_0 = scope.Var(params_name[0]);
Tensor* tensor_0 = var_0->GetMutable<Tensor>();
set_tensor<float>(tensor_0, std::vector<int64_t>({3, 2}));
// variable 1
Variable* var_1 = scope.Var(params_name[1]);
Tensor* tensor_1 = var_1->GetMutable<Tensor>();
set_tensor<int8_t>(tensor_1, std::vector<int64_t>({10, 1}));
// Set combined parameters
fbs::CombinedParamsDesc combined_param;
SetCombinedParamsWithScope(scope, params_name, &combined_param);
/* --------- Check scope ---------- */
auto check_params = [&](const CombinedParamsDescReadAPI& desc) {
Scope scope_l;
SetScopeWithCombinedParams(&scope_l, desc);
// variable 0
Variable* var_l0 = scope_l.FindVar(params_name[0]);
CHECK(var_l0);
const Tensor& tensor_l0 = var_l0->Get<Tensor>();
CHECK(TensorCompareWith(*tensor_0, tensor_l0));
// variable 1
Variable* var_l1 = scope_l.FindVar(params_name[1]);
CHECK(var_l1);
const Tensor& tensor_l1 = var_l1->Get<Tensor>();
CHECK(TensorCompareWith(*tensor_1, tensor_l1));
};
check_params(combined_param);
/* --------- Cache scope ---------- */
std::vector<char> cache;
cache.resize(combined_param.buf_size());
std::memcpy(cache.data(), combined_param.data(), combined_param.buf_size());
/* --------- View scope ---------- */
check_params(CombinedParamsDescView(std::move(cache)));
}
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -19,7 +19,7 @@ namespace lite {
namespace fbs {
template <>
std::string OpDesc::GetAttr<std::string>(const std::string& name) const {
std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
if (!it->s()) {
return std::string();
......@@ -28,7 +28,7 @@ std::string OpDesc::GetAttr<std::string>(const std::string& name) const {
}
template <>
std::string OpDesc::GetAttr<std::string>(size_t idx) const {
std::string OpDescView::GetAttr<std::string>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
if (!it->s()) {
return std::string();
......@@ -38,43 +38,43 @@ std::string OpDesc::GetAttr<std::string>(size_t idx) const {
template <>
lite::VectorView<std::string, Flatbuffers>
OpDesc::GetAttr<std::vector<std::string>>(const std::string& name) const {
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
CHECK(it) << "Attr " << name << "does not exist.";
return VectorView<std::string>(it->strings());
}
template <>
VectorView<std::string, Flatbuffers> OpDesc::GetAttr<std::vector<std::string>>(
size_t idx) const {
VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
CHECK(it) << "Attr " << idx << "does not exist.";
return VectorView<std::string>(it->strings());
}
#define GET_ATTR_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
return it->fb_f__(); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return it->fb_f__(); \
#define GET_ATTR_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
return it->fb_f__(); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return it->fb_f__(); \
}
#define GET_ATTRS_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
......@@ -88,6 +88,27 @@ GET_ATTR_IMPL(int64_t, l);
GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL(std::vector<float>, floats);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
#undef GET_ATTR_IMPL
#undef GET_ATTRS_IMPL
#define ATTR_IMPL(T, fb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
return (*GetKeyIterator(name, desc_->attrs))->fb_f__; \
} \
template <> \
void OpDesc::SetAttr(const std::string& name, const T& v) { \
(*GetKeyIterator(name, desc_->attrs))->fb_f__ = v; \
}
ATTR_IMPL(int32_t, i);
ATTR_IMPL(int16_t, block_idx);
ATTR_IMPL(float, f);
ATTR_IMPL(bool, b);
ATTR_IMPL(int64_t, l);
ATTR_IMPL(std::vector<int>, ints);
ATTR_IMPL(std::vector<float>, floats);
ATTR_IMPL(std::vector<int64_t>, longs);
#undef ATTR_IMPL
} // namespace fbs
} // namespace lite
......
......@@ -17,6 +17,7 @@
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/model_parser/base/op_desc.h"
......@@ -29,9 +30,9 @@ namespace paddle {
namespace lite {
namespace fbs {
class OpDesc : public OpDescAPI {
class OpDescView : public OpDescAPI {
public:
explicit OpDesc(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); }
explicit OpDescView(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type()->str(); }
......@@ -137,7 +138,7 @@ class OpDesc : public OpDescAPI {
// caused by different building options.
public:
OpDesc() { NotImplemented(); }
OpDescView() { NotImplemented(); }
bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
}
......@@ -184,7 +185,7 @@ class OpDesc : public OpDescAPI {
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of OpDesc is temporarily "
LOG(FATAL) << "The additional interfaces of OpDescView is temporarily "
"unavailable in read-only mode.";
}
std::string type_;
......@@ -194,6 +195,93 @@ class OpDesc : public OpDescAPI {
std::map<std::string, AttrType> attr_types_;
};
class OpDesc : public OpDescAPI {
public:
OpDesc() : owned_(true), desc_(new proto::OpDescT()) {}
explicit OpDesc(proto::OpDescT* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type; }
void SetType(const std::string& type) override { desc_->type = type; }
std::vector<std::string> Input(const std::string& param) const override {
return (*GetKeyIterator(param, desc_->inputs))->arguments;
}
std::vector<std::string> InputArgumentNames() const override {
VLOG(5) << "This function call is expensive.";
std::vector<std::string> tmp;
for (const auto& input : desc_->inputs) {
tmp.push_back(input->parameter);
}
return tmp;
}
void SetInput(const std::string& param,
const std::vector<std::string>& args) override {
std::unique_ptr<proto::OpDesc_::VarT> var(new proto::OpDesc_::VarT);
var->parameter = param;
var->arguments = args;
InsertPair(param, std::move(var), &desc_->inputs);
}
std::vector<std::string> Output(const std::string& param) const override {
return (*GetKeyIterator(param, desc_->outputs))->arguments;
}
std::vector<std::string> OutputArgumentNames() const override {
VLOG(5) << "This function call is expensive.";
std::vector<std::string> tmp;
for (const auto& output : desc_->outputs) {
tmp.push_back(output->parameter);
}
return tmp;
}
void SetOutput(const std::string& param,
const std::vector<std::string>& args) override {
std::unique_ptr<proto::OpDesc_::VarT> var(new proto::OpDesc_::VarT);
var->parameter = param;
var->arguments = args;
InsertPair(param, std::move(var), &desc_->outputs);
}
bool HasAttr(const std::string& name) const override {
return HasKey(name, desc_->attrs);
}
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
return ConvertAttrType((*GetKeyIterator(name, desc_->attrs))->type);
}
std::vector<std::string> AttrNames() const override {
VLOG(5) << "This function call is expensive.";
std::vector<std::string> tmp;
for (const auto& attr : desc_->attrs) {
tmp.push_back(attr->name);
}
return tmp;
}
template <typename T>
void SetAttr(const std::string& name, const T& v);
template <typename T>
T GetAttr(const std::string& name) const;
proto::OpDescT* raw_desc() { return desc_; }
~OpDesc() {
if (owned_) {
delete desc_;
}
}
private:
bool owned_{false};
proto::OpDescT* desc_{nullptr};
};
} // namespace fbs
} // namespace lite
} // namespace paddle
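A minimal usage sketch (not part of this commit) of the writable OpDesc added above; inputs and outputs are kept sorted by parameter name through InsertPair, so later lookups can use binary search:
#include "lite/model_parser/flatbuffers/op_desc.h"
void OpDescSketch() {
  paddle::lite::fbs::OpDesc op;  // owns a fresh proto::OpDescT
  op.SetType("softmax");
  op.SetInput("X", {"x"});
  op.SetOutput("Out", {"out"});
  // Input() locates the "X" entry via GetKeyIterator (binary search by key).
  std::vector<std::string> args = op.Input("X");  // {"x"}
  (void)args;
}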
......@@ -86,8 +86,9 @@ class CombinedParamsDescView : public CombinedParamsDescReadAPI {
void InitParams() {
desc_ = proto::GetCombinedParamsDesc(buf_.data());
params_.reserve(GetParamsSize());
for (size_t idx = 0; idx < GetParamsSize(); ++idx) {
size_t params_size = desc_->params()->size();
params_.reserve(params_size);
for (size_t idx = 0; idx < params_size; ++idx) {
params_.push_back(ParamDescView(desc_->params()->Get(idx)));
}
}
......@@ -114,6 +115,7 @@ class ParamDesc : public ParamDescAPI {
}
explicit ParamDesc(proto::ParamDescT* desc) : desc_(desc) {
desc_->variable.Set(proto::ParamDesc_::LoDTensorDescT());
lod_tensor_ = desc_->variable.AsLoDTensorDesc();
CHECK(lod_tensor_);
}
......@@ -165,6 +167,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
raw_buf->UnPackTo(&desc_);
SyncParams();
}
const ParamDescReadAPI* GetParamDesc(size_t idx) const override {
return &params_[idx];
}
......@@ -172,7 +175,8 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
size_t GetParamsSize() const override { return desc_.params.size(); }
ParamDescWriteAPI* AddParamDesc() override {
desc_.params.push_back(std::unique_ptr<proto::ParamDescT>());
desc_.params.push_back(
std::unique_ptr<proto::ParamDescT>(new proto::ParamDescT));
SyncParams();
return &params_[params_.size() - 1];
}
......
......@@ -19,14 +19,15 @@ namespace lite {
namespace fbs {
template <>
proto::BlockDesc const* ProgramDesc::GetBlock<proto::BlockDesc>(
proto::BlockDesc const* ProgramDescView::GetBlock<proto::BlockDesc>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return desc_->blocks()->Get(idx);
}
template <>
BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
BlockDescView const* ProgramDescView::GetBlock<BlockDescView>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return &blocks_[idx];
}
......
......@@ -26,11 +26,11 @@ namespace paddle {
namespace lite {
namespace fbs {
class ProgramDesc : public ProgramDescAPI {
class ProgramDescView : public ProgramDescAPI {
public:
ProgramDesc() = default;
explicit ProgramDesc(const std::vector<char>& buf) { Init(buf); }
explicit ProgramDesc(std::vector<char>&& buf) {
ProgramDescView() = default;
explicit ProgramDescView(const std::vector<char>& buf) { Init(buf); }
explicit ProgramDescView(std::vector<char>&& buf) {
Init(std::forward<std::vector<char>>(buf));
}
......@@ -50,11 +50,11 @@ class ProgramDesc : public ProgramDescAPI {
desc_ = proto::GetProgramDesc(buf_.data());
blocks_.reserve(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx)));
blocks_.push_back(BlockDescView(desc_->blocks()->Get(idx)));
}
}
void CopyFrom(const ProgramDesc& other) {
void CopyFrom(const ProgramDescView& other) {
buf_ = other.buf();
Init(buf_);
}
......@@ -70,7 +70,7 @@ class ProgramDesc : public ProgramDescAPI {
return nullptr;
}
const std::vector<BlockDesc>& GetBlocks() const { return blocks_; }
const std::vector<BlockDescView>& GetBlocks() const { return blocks_; }
bool HasVersion() const override { return desc_->version() != nullptr; }
......@@ -86,13 +86,13 @@ class ProgramDesc : public ProgramDescAPI {
private:
proto::ProgramDesc const* desc_;
std::vector<char> buf_;
std::vector<BlockDesc> blocks_;
std::vector<BlockDescView> blocks_;
private:
ProgramDesc& operator=(const ProgramDesc&) = delete;
ProgramDesc(const ProgramDesc&) = delete;
ProgramDescView& operator=(const ProgramDescView&) = delete;
ProgramDescView(const ProgramDescView&) = delete;
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of ProgramDesc is temporarily "
LOG(FATAL) << "The additional interfaces of ProgramDescView is temporarily "
"unavailable in read-only mode.";
}
};
......
......@@ -14,6 +14,11 @@
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/model_parser/base/traits.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
......@@ -139,6 +144,71 @@ inline proto::AttrType ConvertAttrType(lite::OpAttrType type) {
#undef CASE
}
template <typename FlatbuffersMapT, typename KeyT = std::string>
KeyT GetKey(const std::unique_ptr<FlatbuffersMapT>& object);
#define GET_KEY_INSTANCE(type, key, key_type) \
template <> \
inline key_type GetKey<proto::type>( \
const std::unique_ptr<proto::type>& object) { \
return object->key; \
}
GET_KEY_INSTANCE(OpDesc_::VarT, parameter, std::string);
GET_KEY_INSTANCE(OpDesc_::AttrT, name, std::string);
#undef GET_KEY_INSTANCE
template <typename MapT, typename KeyT = std::string>
struct CompareLessThanKey {
bool operator()(const std::unique_ptr<MapT>& lhs, const KeyT& rhs) {
return GetKey(lhs) < rhs;
}
bool operator()(const KeyT& lhs, const std::unique_ptr<MapT>& rhs) {
return lhs < GetKey(rhs);
}
};
template <typename MapT>
struct CompareLessThan {
bool operator()(const std::unique_ptr<MapT>& lhs,
const std::unique_ptr<MapT>& rhs) {
return GetKey(lhs) < GetKey(rhs);
}
};
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
typename std::vector<std::unique_ptr<MapT>>::const_iterator GetKeyIterator(
const KeyT& key, const std::vector<std::unique_ptr<MapT>>& vector) {
auto iter =
std::lower_bound(vector.begin(), vector.end(), key, CompareFunc());
CHECK(GetKey(*iter) == key);
return iter;
}
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
void InsertPair(const KeyT& key,
std::unique_ptr<MapT>&& val,
std::vector<std::unique_ptr<MapT>>* vector) {
auto iter =
std::lower_bound(vector->begin(), vector->end(), key, CompareFunc());
vector->insert(iter, std::forward<std::unique_ptr<MapT>>(val));
}
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
bool HasKey(const KeyT& key, const std::vector<std::unique_ptr<MapT>>& vector) {
return std::binary_search(vector.begin(), vector.end(), key, CompareFunc());
}
template <typename MapT, typename CompareFunc = CompareLessThan<MapT>>
void Sort(std::vector<std::unique_ptr<MapT>>* vector) {
std::sort(vector->begin(), vector->end(), CompareFunc());
}
} // namespace fbs
} // namespace lite
} // namespace paddle
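The helpers above implement a small sorted-vector map over the flatbuffers object API. A self-contained sketch follows (not from the commit); PairT and its GetKey specialization are hypothetical stand-ins for the generated proto::OpDesc_::VarT type and its GET_KEY_INSTANCE entry.
#include <cassert>
#include <memory>
#include <string>
#include <vector>
#include "lite/model_parser/flatbuffers/traits.h"
namespace paddle {
namespace lite {
namespace fbs {
// Hypothetical element type, analogous to the generated proto::OpDesc_::VarT.
struct PairT {
  std::string parameter;
  std::vector<std::string> arguments;
};
// Key accessor, mirroring GET_KEY_INSTANCE(OpDesc_::VarT, parameter, std::string).
template <>
inline std::string GetKey<PairT>(const std::unique_ptr<PairT>& object) {
  return object->parameter;
}
inline void SortedVectorSketch() {
  std::vector<std::unique_ptr<PairT>> inputs;
  InsertPair(std::string("X"),
             std::unique_ptr<PairT>(new PairT{"X", {"x0"}}),
             &inputs);                               // insertion keeps key order
  assert(HasKey(std::string("X"), inputs));          // binary search by key
  assert((*GetKeyIterator(std::string("X"), inputs))->arguments.size() == 1);
}
}  // namespace fbs
}  // namespace lite
}  // namespace paddle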
......@@ -26,9 +26,9 @@ namespace paddle {
namespace lite {
namespace fbs {
class VarDesc : public VarDescAPI {
class VarDescView : public VarDescAPI {
public:
explicit VarDesc(proto::VarDesc const* desc) : desc_(desc) {}
explicit VarDescView(proto::VarDesc const* desc) : desc_(desc) {}
std::string Name() const override { return desc_->name()->str(); }
......@@ -66,18 +66,79 @@ class VarDesc : public VarDescAPI {
// caused by different building options.
public:
VarDesc() { NotImplemented(); }
VarDescView() { NotImplemented(); }
void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of VarDesc is temporarily "
LOG(FATAL) << "The additional interfaces of VarDescView is temporarily "
"unavailable in read-only mode.";
}
std::vector<int64_t> shape_;
};
class VarDesc : public VarDescAPI {
public:
VarDesc() : owned_(true), desc_(new proto::VarDescT()) {}
explicit VarDesc(proto::VarDescT* desc) : desc_(desc) {
CHECK(desc_);
InitType();
}
std::string Name() const override { return desc_->name; }
void SetName(std::string name) override { desc_->name = name; }
Type GetType() const override { return ConvertVarType(type_->type); }
void SetType(Type type) override {
CHECK(type == VarDescAPI::Type::LOD_TENSOR);
type_->type = ConvertVarType(type);
}
bool Persistable() const override { return desc_->persistable; }
void SetPersistable(bool persistable) override {
desc_->persistable = persistable;
}
std::vector<int64_t> GetShape() const override {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
return type_->lod_tensor->tensor->dims;
}
void SetShape(const std::vector<int64_t>& dims) override {
type_->lod_tensor->tensor->dims = dims;
}
proto::VarDescT* raw_desc() { return desc_; }
~VarDesc() {
if (owned_) {
delete desc_;
}
}
private:
void InitType() {
if (!desc_->type) {
desc_->type = std::unique_ptr<proto::VarTypeT>(new proto::VarTypeT());
desc_->type->lod_tensor =
std::unique_ptr<proto::VarType_::LoDTensorDescT>(
new proto::VarType_::LoDTensorDescT());
desc_->type->lod_tensor->tensor =
std::unique_ptr<proto::VarType_::TensorDescT>(
new proto::VarType_::TensorDescT());
}
type_ = desc_->type.get();
}
bool owned_{false};
proto::VarDescT* desc_{nullptr};
paddle::lite::fbs::proto::VarTypeT* type_{nullptr};
};
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -55,6 +55,8 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
if (opdesc.HasAttr("use_cudnn")) {
param_.use_cudnn = opdesc.GetAttr<bool>("use_cudnn");
}
// TODO(wilber): use cudnn by default when compiling with cuda.
param_.use_cudnn = true;
return true;
}
......
......@@ -117,10 +117,12 @@ class BatchNormComputeTest : public arena::TestCase {
op_desc->SetInput("Mean", {mean_});
op_desc->SetInput("Variance", {variance_});
op_desc->SetOutput("Y", {output_});
op_desc->SetOutput("MeanOut", {mean_out_});
op_desc->SetOutput("VarianceOut", {variance_out_});
op_desc->SetOutput("SavedMean", {saved_mean_});
op_desc->SetOutput("SavedVariance", {saved_variance_});
if (!is_test_) {
op_desc->SetOutput("MeanOut", {mean_out_});
op_desc->SetOutput("VarianceOut", {variance_out_});
op_desc->SetOutput("SavedMean", {saved_mean_});
op_desc->SetOutput("SavedVariance", {saved_variance_});
}
op_desc->SetAttr("epsilon", epsilon_);
op_desc->SetAttr("momentum", momentum_);
op_desc->SetAttr("use_global_stats", use_global_stats_);
......@@ -159,6 +161,9 @@ TEST(BatchNorm, precision) {
Place place;
#if defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU);
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_NPU)
place = TARGET(kNPU);
#else
......
......@@ -206,6 +206,11 @@ void TestEltDims(Place place, float abs_error) {
void TestEltTypes(Place place, float abs_error) {
for (auto elt_type :
std::vector<std::string>{"add", "sub", "mul", "div", "max"}) {
// Huawei Ascend NPU DDK has bugs in div, and does not support max yet
if (place == TARGET(kHuaweiAscendNPU) &&
(elt_type == "div" || elt_type == "max")) {
continue;
}
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0);
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1);
}
......@@ -214,6 +219,11 @@ void TestEltTypes(Place place, float abs_error) {
void TestEltFuseAct(Place place, float abs_error) {
for (auto elt_type :
std::vector<std::string>{"add", "sub", "mul", "div", "max"}) {
// Huawei Ascend NPU DDK has bugs in div, and does not support max yet
if (place == TARGET(kHuaweiAscendNPU) &&
(elt_type == "div" || elt_type == "max")) {
continue;
}
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0, "relu");
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1, "relu");
}
......@@ -226,6 +236,9 @@ TEST(Elementwise, precision) {
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
......
......@@ -322,6 +322,10 @@ void TestPoolPaddings(Place place, float abs_error = 2e-5) {
{1, 1},
{0, 0, 1, 1},
{2, 2});
// Ascend restriction: padT should equals padB, and padL should equals padR
if (place == TARGET(kHuaweiAscendNPU)) {
continue;
}
TestPoolHelper(place,
abs_error,
{2, 3, 6, 7},
......@@ -381,6 +385,9 @@ TEST(Pool, precision) {
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU);
#else
......