Commit 92686bae authored by yejianwu

merge with origin master

......@@ -3,6 +3,7 @@ stages:
- pycodestyle
- platform_compitable_tests
- ops_test
- api_test
- ops_benchmark
- extra_tests
......@@ -21,7 +22,13 @@ ops_test:
stage: ops_test
script:
- if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi
- python tools/bazel_adb_run.py --target="//mace/ops:ops_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS
- python tools/bazel_adb_run.py --target="//mace/ops:ops_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS
api_test:
stage: api_test
script:
- if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi
- python tools/bazel_adb_run.py --target="//mace/test:mace_api_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS
ops_benchmark:
stage: ops_benchmark
......
......@@ -178,6 +178,9 @@ MaceStatus MaceEngine::Impl::Run(
std::vector<Tensor *> input_tensors;
std::vector<Tensor *> output_tensors;
for (auto &input : inputs) {
MACE_CHECK(input.second.shape().size() == 4,
"The input shape must be 4-dimensional in NHWC format;"
" please use 1 to fill missing dimensions");
Tensor *input_tensor =
ws_->GetTensor(MakeString("mace_input_node_", input.first, ":0"));
input_tensor->Resize(input.second.shape());
......@@ -190,6 +193,11 @@ MaceStatus MaceEngine::Impl::Run(
input_tensors.push_back(input_tensor);
}
for (auto &output : *outputs) {
if (device_type_ == DeviceType::OPENCL) {
MACE_CHECK(output.second.shape().size() == 4,
"The output shape must be 4-dimensional in NHWC format;"
" please use 1 to fill missing dimensions");
}
Tensor *output_tensor =
ws_->GetTensor(MakeString("mace_output_node_", output.first + ":0"));
output_tensors.push_back(output_tensor);
......
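With these checks in place, every tensor handed to MaceEngine::Run must carry a 4-D NHWC shape. A minimal caller-side sketch, assuming the public MaceTensor API used in mace_api_test.cc later in this commit (the name "input0" is illustrative):

// A logical {batch, size} input goes through the 4-D NHWC interface by
// filling the missing H and W dimensions with 1, as the check suggests.
std::vector<int64_t> shape{3, 1, 1, 16};  // {N, H, W, C} with H = W = 1
auto buffer_in = std::shared_ptr<float>(new float[3 * 16],
                                        std::default_delete<float[]>());
std::map<std::string, mace::MaceTensor> inputs;
inputs["input0"] = mace::MaceTensor(shape, buffer_in);  // passes the check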
......@@ -86,6 +86,7 @@ extern void Register_Conv2D(OperatorRegistry *op_registry);
extern void Register_CWise(OperatorRegistry *op_registry);
extern void Register_DepthToSpace(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_Dequantize(OperatorRegistry *op_registry);
extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_FullyConnected(OperatorRegistry *op_registry);
......@@ -98,7 +99,9 @@ extern void Register_Pad(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_PSROIAlign(OperatorRegistry *op_registry);
extern void Register_Quantize(OperatorRegistry *op_registry);
extern void Register_ReOrganize(OperatorRegistry *op_registry);
extern void Register_Requantize(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry);
......@@ -124,6 +127,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_CWise(this);
ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
ops::Register_Eltwise(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
......@@ -136,6 +140,8 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_Pooling(this);
ops::Register_Proposal(this);
ops::Register_PSROIAlign(this);
ops::Register_Quantize(this);
ops::Register_Requantize(this);
ops::Register_ReOrganize(this);
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
......
......@@ -108,12 +108,25 @@ class Operator : public OperatorBase {
inputs_.push_back(tensor);
}
for (const std::string &output_str : operator_def.output()) {
for (size_t i = 0; i < operator_def.output().size(); ++i) {
const std::string output_str = operator_def.output()[i];
if (ws->HasTensor(output_str)) {
outputs_.push_back(ws->GetTensor(output_str));
} else {
MACE_CHECK(
operator_def.output_type().size() == 0
|| operator_def.output().size() == operator_def.output_type().size(),
"operator output size != operator output type size",
operator_def.output().size(),
operator_def.output_type().size());
DataType output_type;
if (i < operator_def.output_type().size()) {
output_type = operator_def.output_type()[i];
} else {
output_type = DataTypeToEnum<T>::v();
}
outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
output_str, GetDeviceAllocator(D), DataTypeToEnum<T>::v())));
output_str, GetDeviceAllocator(D), output_type)));
}
}
}
......
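The type-resolution rule added above reads cleanly as a small helper: an output takes the explicitly declared type at its index if one exists, otherwise the operator's template type. A standalone sketch (not part of the commit), assuming the OperatorDef accessors behave as in the snippet:

DataType ResolveOutputType(const OperatorDef &operator_def, size_t i,
                           DataType default_type) {
  if (i < static_cast<size_t>(operator_def.output_type().size())) {
    return operator_def.output_type()[i];  // explicit per-output type
  }
  return default_type;  // fall back to DataTypeToEnum<T>::v()
}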
......@@ -81,15 +81,19 @@ void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
}
VLOG(3) << "Model data size: " << model_data_size;
if (type == DeviceType::CPU || type == DeviceType::NEON) {
tensor_buffer_ = std::unique_ptr<Buffer>(
new Buffer(GetDeviceAllocator(type), model_data_ptr, model_data_size));
} else {
tensor_buffer_ = std::unique_ptr<Buffer>(
new Buffer(GetDeviceAllocator(type), model_data_size));
tensor_buffer_->Map(nullptr);
tensor_buffer_->Copy(model_data_ptr, 0, model_data_size);
tensor_buffer_->UnMap();
if (model_data_size > 0) {
if (type == DeviceType::CPU || type == DeviceType::NEON) {
tensor_buffer_ = std::unique_ptr<Buffer>(
new Buffer(GetDeviceAllocator(type),
model_data_ptr,
model_data_size));
} else {
tensor_buffer_ = std::unique_ptr<Buffer>(
new Buffer(GetDeviceAllocator(type), model_data_size));
tensor_buffer_->Map(nullptr);
tensor_buffer_->Copy(model_data_ptr, 0, model_data_size);
tensor_buffer_->UnMap();
}
}
for (auto &const_tensor : net_def.tensors()) {
......
......@@ -163,6 +163,8 @@ bool RunModel(const std::vector<std::string> &input_names,
static_cast<GPUPriorityHint>(FLAGS_gpu_priority_hint));
}
// DO NOT USE the tmp directory;
// please use the app's own directory instead.
const std::string kernel_file_path =
"/data/local/tmp/mace_run/cl";
......
......@@ -28,9 +28,12 @@ cc_library(
"opencl/*.h",
"arm/*.h",
]),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([
"-DMACE_ENABLE_OPENCL",
]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
copts = if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
linkopts = if_android(["-lm"]),
deps = [
"//mace/core",
......@@ -48,9 +51,12 @@ cc_test(
"opencl/*_test.cc",
],
),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([
"-DMACE_ENABLE_OPENCL",
]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
copts = if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
......
......@@ -362,14 +362,14 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
};
} else if (use_neon_1x1_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK1x1S1(input_data,
Conv2dNeonK1x1S1(pad_input,
filter_data,
batch,
height,
width,
extra_input_height,
extra_input_width,
input_channels,
channels,
output_data);
pad_output);
};
} else {
conv_func = [=](const float *pad_input, float *pad_output) {
......
......@@ -34,10 +34,10 @@ void FullyConnectedFunctor<DeviceType::NEON,
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
float *output_ptr = output->mutable_data<float>();
Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
for (int i = 0; i < N; ++i) {
Gemv(weight_ptr, input_ptr, input_size, output_size, output_ptr);
for (int j = 0; j < output_size; ++j) {
output_ptr[j] += bias_ptr[j];
output_ptr[j + i * output_size] += bias_ptr[j];
}
}
......
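The fix replaces a single Gemv over the whole batch with the new batched signature plus a per-row bias broadcast. The shapes implied by the call, for reference:

// Shape sketch for the batched call above:
//   weight: output_size x input_size   (shared across the batch)
//   input : N x input_size             (one row per sample)
//   output: N x output_size            (bias_ptr added to every row)
Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);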
......@@ -566,6 +566,7 @@ inline void GemmTile(const float *A,
}
} // namespace
// A: height x K, B: K x width, C: height x width
void Gemm(const float *A,
const float *B,
const index_t batch,
......@@ -573,6 +574,12 @@ void Gemm(const float *A,
const index_t K,
const index_t width,
float *C) {
if (width == 1) {
for (index_t b = 0; b < batch; ++b) {
Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height);
}
return;
}
memset(C, 0, sizeof(float) * batch * height * width);
......@@ -628,6 +635,7 @@ void Gemm(const float *A,
} // n
}
// A: height x K, B: K x width, C: height x width
void GemmRef(const float *A,
const float *B,
const index_t height,
......@@ -647,19 +655,24 @@ void GemmRef(const float *A,
void GemvRef(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr) {
memset(out_ptr, 0, sizeof(float) * height);
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
out_ptr[h] += v_ptr[w] * m_ptr[h * width + w];
memset(out_ptr, 0, sizeof(float) * height * batch);
for (int b = 0; b < batch; ++b) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
out_ptr[h + b * height] += v_ptr[w + b * width] * m_ptr[h * width + w];
}
}
}
}
// M: height x width, Vin: width x 1, Vout: height x 1
void Gemv(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr) {
......@@ -669,88 +682,90 @@ void Gemv(const float *m_ptr,
index_t remain_w = width - (width_d4 << 2);
index_t remain_h = height - (height_d4 << 2);
for (index_t b = 0; b < batch; ++b) {
#pragma omp parallel for
for (index_t h = 0; h < height_d4; ++h) {
const float *m_ptr0 = m_ptr + h * width * 4;
const float *m_ptr1 = m_ptr0 + width;
const float *m_ptr2 = m_ptr1 + width;
const float *m_ptr3 = m_ptr2 + width;
const float *v_ptr0 = v_ptr;
float *out_ptr0 = out_ptr + h * 4;
float32x4_t vm0, vm1, vm2, vm3;
float32x4_t vv;
float32x4_t vsum0 = vdupq_n_f32(0.f);
float32x4_t vsum1 = vdupq_n_f32(0.f);
float32x4_t vsum2 = vdupq_n_f32(0.f);
float32x4_t vsum3 = vdupq_n_f32(0.f);
for (index_t w = 0; w < width_d4; ++w) {
vm0 = vld1q_f32(m_ptr0);
vm1 = vld1q_f32(m_ptr1);
vm2 = vld1q_f32(m_ptr2);
vm3 = vld1q_f32(m_ptr3);
vv = vld1q_f32(v_ptr0);
vsum0 = vmlaq_f32(vsum0, vm0, vv);
vsum1 = vmlaq_f32(vsum1, vm1, vv);
vsum2 = vmlaq_f32(vsum2, vm2, vv);
vsum3 = vmlaq_f32(vsum3, vm3, vv);
m_ptr0 += 4;
m_ptr1 += 4;
m_ptr2 += 4;
m_ptr3 += 4;
v_ptr0 += 4;
}
float sum0 = vaddvq_f32(vsum0);
float sum1 = vaddvq_f32(vsum1);
float sum2 = vaddvq_f32(vsum2);
float sum3 = vaddvq_f32(vsum3);
// handle remaining w
for (index_t w = 0; w < remain_w; ++w) {
sum0 += m_ptr0[0] * v_ptr0[0];
sum1 += m_ptr1[0] * v_ptr0[0];
sum2 += m_ptr2[0] * v_ptr0[0];
sum3 += m_ptr3[0] * v_ptr0[0];
m_ptr0++;
m_ptr1++;
m_ptr2++;
m_ptr3++;
v_ptr0++;
for (index_t h = 0; h < height_d4; ++h) {
const float *m_ptr0 = m_ptr + h * width * 4;
const float *m_ptr1 = m_ptr0 + width;
const float *m_ptr2 = m_ptr1 + width;
const float *m_ptr3 = m_ptr2 + width;
const float *v_ptr0 = v_ptr + b * width;
float *out_ptr0 = out_ptr + h * 4 + b * height;
float32x4_t vm0, vm1, vm2, vm3;
float32x4_t vv;
float32x4_t vsum0 = vdupq_n_f32(0.f);
float32x4_t vsum1 = vdupq_n_f32(0.f);
float32x4_t vsum2 = vdupq_n_f32(0.f);
float32x4_t vsum3 = vdupq_n_f32(0.f);
for (index_t w = 0; w < width_d4; ++w) {
vm0 = vld1q_f32(m_ptr0);
vm1 = vld1q_f32(m_ptr1);
vm2 = vld1q_f32(m_ptr2);
vm3 = vld1q_f32(m_ptr3);
vv = vld1q_f32(v_ptr0);
vsum0 = vmlaq_f32(vsum0, vm0, vv);
vsum1 = vmlaq_f32(vsum1, vm1, vv);
vsum2 = vmlaq_f32(vsum2, vm2, vv);
vsum3 = vmlaq_f32(vsum3, vm3, vv);
m_ptr0 += 4;
m_ptr1 += 4;
m_ptr2 += 4;
m_ptr3 += 4;
v_ptr0 += 4;
}
float sum0 = vaddvq_f32(vsum0);
float sum1 = vaddvq_f32(vsum1);
float sum2 = vaddvq_f32(vsum2);
float sum3 = vaddvq_f32(vsum3);
// handle remaining w
for (index_t w = 0; w < remain_w; ++w) {
sum0 += m_ptr0[0] * v_ptr0[0];
sum1 += m_ptr1[0] * v_ptr0[0];
sum2 += m_ptr2[0] * v_ptr0[0];
sum3 += m_ptr3[0] * v_ptr0[0];
m_ptr0++;
m_ptr1++;
m_ptr2++;
m_ptr3++;
v_ptr0++;
}
*out_ptr0++ = sum0;
*out_ptr0++ = sum1;
*out_ptr0++ = sum2;
*out_ptr0++ = sum3;
}
*out_ptr0++ = sum0;
*out_ptr0++ = sum1;
*out_ptr0++ = sum2;
*out_ptr0++ = sum3;
}
// handle remaining h
index_t remain_start_height = height_d4 << 2;
// handle remaining h
index_t remain_start_height = height_d4 << 2;
#pragma omp parallel for
for (index_t h = 0; h < remain_h; ++h) {
float32x4_t vsum0 = vdupq_n_f32(0.f);
const float *m_ptr0 = m_ptr + (h + remain_start_height) * width;
const float *v_ptr0 = v_ptr;
for (index_t w = 0; w < width_d4; ++w) {
float32x4_t vm = vld1q_f32(m_ptr0);
float32x4_t vv = vld1q_f32(v_ptr0);
vsum0 = vmlaq_f32(vsum0, vm, vv);
m_ptr0 += 4;
v_ptr0 += 4;
}
float sum = vaddvq_f32(vsum0);
for (index_t w = 0; w < remain_w; ++w) {
sum += m_ptr0[0] * v_ptr0[0];
m_ptr0++;
v_ptr0++;
for (index_t h = 0; h < remain_h; ++h) {
float32x4_t vsum0 = vdupq_n_f32(0.f);
const float *m_ptr0 = m_ptr + (h + remain_start_height) * width;
const float *v_ptr0 = v_ptr;
for (index_t w = 0; w < width_d4; ++w) {
float32x4_t vm = vld1q_f32(m_ptr0);
float32x4_t vv = vld1q_f32(v_ptr0);
vsum0 = vmlaq_f32(vsum0, vm, vv);
m_ptr0 += 4;
v_ptr0 += 4;
}
float sum = vaddvq_f32(vsum0);
for (index_t w = 0; w < remain_w; ++w) {
sum += m_ptr0[0] * v_ptr0[0];
m_ptr0++;
v_ptr0++;
}
out_ptr[remain_start_height + h] = sum;
}
out_ptr[remain_start_height + h] = sum;
}
#else
GemvRef(m_ptr, v_ptr, width, height, out_ptr);
GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr);
#endif
}
......
......@@ -41,12 +41,14 @@ void GemmRef(const float *A,
void Gemv(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr);
void GemvRef(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr);
......
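A usage sketch for the batched signature, with the shapes the GemvRef loops imply: the matrix is height x width and shared across the batch, the vector block is batch x width, and the output is batch x height. Buffer names here are illustrative:

std::vector<float> M(height * width);    // shared matrix
std::vector<float> V(batch * width);     // one input vector per sample
std::vector<float> Out(batch * height);  // one output vector per sample
kernels::Gemv(M.data(), V.data(), batch, width, height, Out.data());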
......@@ -70,8 +70,8 @@ TEST(GEMMTest, gemv) {
[&gen, &nd] {
return nd(gen);
});
kernels::Gemv(A.get(), B.get(), K, N, C.get());
kernels::GemvRef(A.get(), B.get(), K, N, C_ref.get());
kernels::Gemv(A.get(), B.get(), 1, K, N, C.get());
kernels::GemvRef(A.get(), B.get(), 1, K, N, C_ref.get());
for (int i = 0; i < N; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_QUANTIZE_H_
#define MACE_KERNELS_QUANTIZE_H_
#include <vector>
#include <algorithm>
#include <limits>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
template<typename T>
inline void AdjustRange(const float in_min_data,
const float in_max_data,
float *out_min_data,
float *out_max_data) {
// Re-range so the float range includes zero and
// zero maps exactly to an integer uint8 value.
const float quantized_max = std::numeric_limits<uint8_t>::max();
float out_min = fminf(0.f, in_min_data);
float out_max = fmaxf(0.f, in_max_data);
if (out_min < 0.f) {
float stepsize = (in_max_data - in_min_data) / quantized_max;
float quantized_zero = -in_min_data / stepsize;
float quantized_zero_near_int = roundf(quantized_zero);
if (fabs(quantized_zero - quantized_zero_near_int) > 1e-6) {
if (quantized_zero < quantized_zero_near_int) {
// keep out_max fixed, and move out_min
stepsize = out_max / (quantized_max - quantized_zero_near_int);
out_min = out_max - quantized_max * stepsize;
} else {
// keep out_min fixed, and move out_max
stepsize = -out_min / quantized_zero_near_int;
out_max = out_min + quantized_max * stepsize;
}
}
}
*out_min_data = out_min;
*out_max_data = out_max;
}
template<typename T>
inline T Saturate(float value) {
int rounded_value = static_cast<int>(value);
if (rounded_value <= std::numeric_limits<T>::lowest()) {
return std::numeric_limits<T>::lowest();
} else if (rounded_value >= std::numeric_limits<T>::max()) {
return std::numeric_limits<T>::max();
} else {
return static_cast<T>(rounded_value);
}
}
template<DeviceType D, typename T>
struct QuantizeFunctor;
template<>
struct QuantizeFunctor<CPU, uint8_t> {
QuantizeFunctor() {}
void operator()(const Tensor *input,
const Tensor *in_min,
const Tensor *in_max,
Tensor *output,
Tensor *out_min,
Tensor *out_max,
StatsFuture *future) {
const float *input_data = input->data<float>();
const float in_min_data = in_min->data<float>()[0];
const float in_max_data = in_max->data<float>()[0];
uint8_t *output_data = output->mutable_data<uint8_t>();
float *out_min_data = out_min->mutable_data<float>();
float *out_max_data = out_max->mutable_data<float>();
AdjustRange<uint8_t>(in_min_data, in_max_data, out_min_data, out_max_data);
float recip_stepsize = 255.f / (out_max_data[0] - out_min_data[0]);
for (int i = 0; i < input->size(); ++i) {
output_data[i] = Saturate<uint8_t>(roundf(
(input_data[i] - in_min_data) * recip_stepsize));
}
}
};
template<DeviceType D, typename T>
struct DequantizeFunctor;
template<>
struct DequantizeFunctor<CPU, uint8_t> {
DequantizeFunctor() {}
void operator()(const Tensor *input,
const Tensor *in_min,
const Tensor *in_max,
Tensor *output,
StatsFuture *future) {
const uint8_t *input_data = input->data<uint8_t>();
const float in_min_data = in_min->data<float>()[0];
const float in_max_data = in_max->data<float>()[0];
float *output_data = output->mutable_data<float>();
float stepsize = (in_max_data - in_min_data) / 255.0;
for (int i = 0; i < input->size(); ++i) {
output_data[i] = in_min_data + stepsize * input_data[i];
}
}
};
template<DeviceType D, typename T>
struct RequantizeFunctor;
template<>
struct RequantizeFunctor<CPU, uint8_t> {
RequantizeFunctor() {}
void operator()(const Tensor *input,
const Tensor *in_min,
const Tensor *in_max,
const Tensor *rerange_min,
const Tensor *rerange_max,
Tensor *output,
Tensor *out_min,
Tensor *out_max,
StatsFuture *future) {
const int *input_data = input->data<int>();
const float in_min_data = in_min->data<float>()[0];
const float in_max_data = in_max->data<float>()[0];
float rerange_min_data;
float rerange_max_data;
int min_val = std::numeric_limits<int>::max();
int max_val = std::numeric_limits<int>::lowest();
double si =
(in_max_data - in_min_data) / std::numeric_limits<uint32_t>::max();
if (rerange_min == nullptr && rerange_max == nullptr) {
for (int i = 0; i < input->size(); ++i) {
min_val = std::min(min_val, input_data[i]);
max_val = std::max(max_val, input_data[i]);
}
rerange_min_data = min_val * si;
rerange_max_data = max_val * si;
} else {
rerange_min_data = rerange_min->data<float>()[0];
rerange_max_data = rerange_max->data<float>()[0];
}
uint8_t *output_data = output->mutable_data<uint8_t>();
float *out_min_data = out_min->mutable_data<float>();
float *out_max_data = out_max->mutable_data<float>();
AdjustRange<uint8_t>(rerange_min_data,
rerange_max_data,
out_min_data,
out_max_data);
/**
* f = qi * si = min_o + qo * so
* => qo = (qi * si - min_o) / so
* = qi * (si/so) - min_o / so
* = qi * (si / so) + zo
*
* zo = -min_o / so
*
*/
float so =
(out_max_data[0] - out_min_data[0]) / std::numeric_limits<uint8_t>::max();
double step_ratio = si / so;
float quantized_out_zero = -out_min_data[0] / so;
for (int i = 0; i < output->size(); ++i) {
output_data[i] =
Saturate<uint8_t>(roundf(
quantized_out_zero + input_data[i] * step_ratio));
}
}
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_QUANTIZE_H_
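A worked pass through AdjustRange helps check the arithmetic; the numbers below match the expected values in quantize_test.cc later in this commit (in_min = -3, in_max = 5):

// Worked example for AdjustRange<uint8_t>(-3, 5, ...):
//   stepsize       = (5 - (-3)) / 255       ~= 0.031373
//   quantized_zero = -(-3) / stepsize       ~= 95.625, rounds to 96
//   95.625 < 96, so out_max stays at 5 and out_min moves:
//   stepsize       = 5 / (255 - 96)         ~= 0.031447
//   out_min        = 5 - 255 * 0.031447     ~= -3.01887
// QuantizeFunctor then uses recip_stepsize = 255 / (5 - (-3.01887)) ~= 31.8,
// so input 1.0 maps to Saturate<uint8_t>(roundf((1.0 - (-3)) * 31.8)) = 127.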
......@@ -37,31 +37,44 @@ struct TransposeFunctor {
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
std::vector<index_t>
in_stride{input_shape[1] * input_shape[2] * input_shape[3],
input_shape[2] * input_shape[3], input_shape[3], 1};
std::vector<index_t>
out_stride{output_shape[1] * output_shape[2] * output_shape[3],
output_shape[2] * output_shape[3], output_shape[3], 1};
if (input->dim_size() == 2) {
MACE_CHECK(dims_[0] == 1 && dims_[1] == 0, "rank-2 transpose only supports dims {1, 0}");
index_t stride_i = input_shape[0];
index_t stride_j = input_shape[1];
for (int i = 0; i < input_shape[0]; ++i) {
for (int j = 0; j < input_shape[1]; ++j) {
output_data[j * stride_i + i] = input_data[i * stride_j + j];
}
}
} else if (input->dim_size() == 4) {
std::vector<index_t>
in_stride{input_shape[1] * input_shape[2] * input_shape[3],
input_shape[2] * input_shape[3], input_shape[3], 1};
std::vector<index_t>
out_stride{output_shape[1] * output_shape[2] * output_shape[3],
output_shape[2] * output_shape[3], output_shape[3], 1};
std::vector<index_t> idim(4, 0);
std::vector<index_t> odim(4, 0);
for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) {
for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) {
for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) {
for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) {
idim[dims_[0]] = odim[0];
idim[dims_[1]] = odim[1];
idim[dims_[2]] = odim[2];
idim[dims_[3]] = odim[3];
std::vector<index_t> idim(4, 0);
std::vector<index_t> odim(4, 0);
for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) {
for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) {
for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) {
for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) {
idim[dims_[0]] = odim[0];
idim[dims_[1]] = odim[1];
idim[dims_[2]] = odim[2];
idim[dims_[3]] = odim[3];
output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1]
+ odim[2] * out_stride[2] + odim[3]] =
input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1]
+ idim[2] * in_stride[2] + idim[3]];
output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1]
+ odim[2] * out_stride[2] + odim[3]] =
input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1]
+ idim[2] * in_stride[2] + idim[3]];
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
}
......
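For the new rank-2 path the index math is worth spelling out: element (i, j) of an {H, W} input moves to (j, i) of the {W, H} output, so the output row stride is H, which is why stride_i is read from input_shape[0]:

// Rank-2 index sketch: output_data[j * H + i] = input_data[i * W + j]
// e.g. a {2, 3} input {1, 2, 3, 4, 5, 6} becomes the {3, 2} output
// {1, 4, 2, 5, 3, 6}, matching the Rank2 unit test below.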
......@@ -34,9 +34,12 @@ cc_library(
["*.h"],
exclude = ["ops_test_util.h"],
),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([
"-DMACE_ENABLE_OPENCL",
]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
copts = if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
deps = [
"//mace/kernels",
],
......@@ -49,9 +52,12 @@ cc_test(
srcs = glob(
["*_test.cc"],
),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([
"-DMACE_ENABLE_OPENCL",
]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
copts = if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
......@@ -65,9 +71,12 @@ cc_test(
name = "ops_benchmark",
testonly = 1,
srcs = glob(["*_benchmark.cc"]),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([
"-DMACE_ENABLE_OPENCL",
]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
copts = if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
......
......@@ -429,7 +429,7 @@ TEST_F(BatchNormOpTest, NEONTest) {
ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
*net.GetOutput("OutputNeon"),
1e-5);
1e-5, 1e-4);
}
} // namespace test
......
......@@ -826,7 +826,7 @@ static void TestNeonArbitraryPadConvNxN(const std::vector<index_t> &shape,
for (int kernel_size : {1, 3, 5}) {
for (int stride : {1, 2}) {
if (stride < kernel_size) {
if (stride <= kernel_size) {
func(kernel_size, kernel_size, stride, stride);
}
}
......
......@@ -337,6 +337,8 @@ TEST_F(FullyConnectedOpTest, TestNEON) {
FullyConnectedTestNEON(1, 7, 7, 32, 16);
FullyConnectedTestNEON(1, 7, 7, 512, 128);
FullyConnectedTestNEON(1, 1, 1, 2048, 1024);
FullyConnectedTestNEON(3, 1, 1, 16, 8);
FullyConnectedTestNEON(3, 7, 7, 32, 16);
}
} // namespace test
......
......@@ -375,90 +375,92 @@ TEST_F(FusedConv2dOpTest, OPENCLUnalignedConvNxNS12) {
namespace {
template<DeviceType D>
void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape,
const int kernel, const int stride,
Padding type) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
// generate random input
static unsigned int seed = time(NULL);
index_t batch = 3 + (rand_r(&seed) % 10);
index_t height = shape[0];
index_t width = shape[1];
index_t input_channels = shape[2] + (rand_r(&seed) % 10);
index_t output_channels = shape[3] + (rand_r(&seed) % 10);
// Construct graph
OpsTestNet net;
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
// generate random input
srand(time(NULL));
index_t batch = 3;
index_t height = shape[0];
index_t width = shape[1];
index_t input_channels = shape[2];
index_t output_channels = shape[3];
// Construct graph
OpsTestNet net;
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
std::vector<float> float_input_data;
GenerateRandomRealTypeData({batch, height, width, input_channels},
&float_input_data);
std::vector<float> float_filter_data;
GenerateRandomRealTypeData(
{kernel_h, kernel_w, output_channels, input_channels},
std::vector<float> float_input_data;
GenerateRandomRealTypeData({batch, height, width, input_channels},
&float_input_data);
std::vector<float> float_filter_data;
GenerateRandomRealTypeData(
{kernel, kernel, output_channels, input_channels},
&float_filter_data);
std::vector<float> float_bias_data;
GenerateRandomRealTypeData({output_channels}, &float_bias_data);
// Add input data
net.AddInputFromArray<D, float>(
std::vector<float> float_bias_data;
GenerateRandomRealTypeData({output_channels}, &float_bias_data);
// Add input data
net.AddInputFromArray<D, float>(
"Input", {batch, height, width, input_channels}, float_input_data);
net.AddInputFromArray<D, float>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels},
net.AddInputFromArray<D, float>(
"Filter", {kernel, kernel, output_channels, input_channels},
float_filter_data);
net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data);
// run on cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data);
// run on gpu
BufferToImage<D, half>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, half>(&net, "Filter", "FilterImage",
kernels::BufferType::CONV2D_FILTER);
BufferToImage<D, half>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
// run on cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// run on gpu
BufferToImage<D, half>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, half>(&net, "Filter", "FilterImage",
kernels::BufferType::CONV2D_FILTER);
BufferToImage<D, half>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataType::DT_HALF))
.Finalize(net.NewOperatorDef());
// Run on device
net.RunOp(D);
// Run on device
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"),
1e-2, 1e-1);
};
ImageToBuffer<D, float>(&net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
for (int kernel_size : {1, 3}) {
for (int stride : {1, 2}) {
func(kernel_size, kernel_size, stride, stride, VALID);
}
}
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"),
1e-2, 1e-1);
}
} // namespace
TEST_F(FusedConv2dOpTest, OPENCLHalfAlignedConvNxNS12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32, 32, 64});
TEST_F(FusedConv2dOpTest, OPENCLHalfAlignedConv1x1S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32, 32, 64}, 1, 1, VALID);
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({31, 37, 31, 37}, 1, 1, SAME);
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32, 32, 64}, 1, 2, VALID);
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({31, 37, 31, 37}, 1, 2, SAME);
}
TEST_F(FusedConv2dOpTest, OPENCLHalfAlignedConv3x3S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32, 32, 64}, 3, 1, VALID);
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({31, 37, 31, 37}, 3, 1, SAME);
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32, 32, 64}, 3, 2, VALID);
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({31, 37, 31, 37}, 3, 2, SAME);
}
namespace {
......
......@@ -52,6 +52,11 @@ class OpDefBuilder {
return *this;
}
OpDefBuilder &OutputType(const std::vector<DataType> &output_type) {
op_def_.set_output_type(output_type);
return *this;
}
OpDefBuilder AddIntArg(const std::string &name, const int value) {
auto arg = op_def_.add_arg();
arg->set_name(name);
......@@ -283,6 +288,16 @@ class OpsTestNet {
return RunOp(DeviceType::CPU);
}
bool RunNet(const NetDef &net_def, const DeviceType device) {
device_ = device;
net_ = CreateNet(op_registry_, net_def, &ws_, device, NetMode::INIT);
if (!net_->Run()) {
return false;
}
net_ = CreateNet(op_registry_, net_def, &ws_, device);
return net_->Run();
}
Tensor *GetOutput(const char *output_name) {
return ws_.GetTensor(output_name);
}
......@@ -451,7 +466,7 @@ struct Expector<EXP_TYPE, RES_TYPE, true> {
auto a = x.data<EXP_TYPE>();
auto b = y.data<RES_TYPE>();
for (int i = 0; i < x.size(); ++i) {
ExpectEqual(a(i), b(i));
ExpectEqual(a[i], b[i]);
}
}
......@@ -489,12 +504,35 @@ struct Expector<EXP_TYPE, RES_TYPE, true> {
}
};
template<typename EXP_TYPE, typename RES_TYPE>
struct Expector<EXP_TYPE, RES_TYPE, false> {
static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
static void Equal(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto a = x.data<EXP_TYPE>();
auto b = y.data<RES_TYPE>();
for (int i = 0; i < x.size(); ++i) {
ExpectEqual(a[i], b[i]);
}
}
static void Near(const Tensor &x, const Tensor &y,
const double rel_err,
const double abs_err) {
Equal(x, y);
}
};
template<typename T>
void ExpectTensorNear(const Tensor &x, const Tensor &y,
const double rel_err = 1e-5,
const double abs_err = 1e-8) {
static_assert(is_floating_point_type<T>::value,
"T is not a floating point type");
Expector<T, T>::Near(x, y, rel_err, abs_err);
}
......@@ -502,9 +540,6 @@ template<typename EXP_TYPE, typename RES_TYPE>
void ExpectTensorNear(const Tensor &x, const Tensor &y,
const double rel_err = 1e-5,
const double abs_err = 1e-8) {
static_assert(is_floating_point_type<EXP_TYPE>::value &&
is_floating_point_type<RES_TYPE>::value,
"T is not a floating point type");
Expector<EXP_TYPE, RES_TYPE>::Near(x, y, rel_err, abs_err);
}
......
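The new non-floating-point Expector specialization lets ExpectTensorNear be instantiated with integral element types; its Near simply forwards to exact element-wise equality, which is what the quantize tests below rely on:

// With the integral specialization, Near() degrades to exact equality:
ExpectTensorNear<uint8_t>(*expected_output, *output);  // element-wise ==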
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/quantize.h"
namespace mace {
namespace ops {
void Register_Quantize(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
.Build(),
QuantizeOp<DeviceType::CPU, uint8_t>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize")
.Device(DeviceType::NEON)
.TypeConstraint<uint8_t>("T")
.Build(),
QuantizeOp<DeviceType::CPU, uint8_t>);
}
void Register_Dequantize(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
.Build(),
DequantizeOp<DeviceType::CPU, uint8_t>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize")
.Device(DeviceType::NEON)
.TypeConstraint<uint8_t>("T")
.Build(),
DequantizeOp<DeviceType::CPU, uint8_t>);
}
void Register_Requantize(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
.Build(),
RequantizeOp<DeviceType::CPU, uint8_t>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize")
.Device(DeviceType::NEON)
.TypeConstraint<uint8_t>("T")
.Build(),
RequantizeOp<DeviceType::CPU, uint8_t>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_QUANTIZE_H_
#define MACE_OPS_QUANTIZE_H_
#include "mace/core/operator.h"
#include "mace/kernels/quantize.h"
namespace mace {
namespace ops {
template<DeviceType D, class T>
class QuantizeOp : public Operator<D, T> {
public:
QuantizeOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {
}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *in_min = this->Input(IN_MIN);
const Tensor *in_max = this->Input(IN_MAX);
MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
Tensor *output = this->Output(OUTPUT);
Tensor *out_min = this->Output(OUT_MIN);
Tensor *out_max = this->Output(OUT_MAX);
output->ResizeLike(input);
out_min->ResizeLike(in_min);
out_max->ResizeLike(in_max);
functor_(input, in_min, in_max, output, out_min, out_max, future);
return true;
}
private:
kernels::QuantizeFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX);
OP_OUTPUT_TAGS(OUTPUT, OUT_MIN, OUT_MAX);
};
template<DeviceType D, class T>
class DequantizeOp : public Operator<D, T> {
public:
DequantizeOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {
}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *in_min = this->Input(IN_MIN);
const Tensor *in_max = this->Input(IN_MAX);
MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
functor_(input, in_min, in_max, output, future);
return true;
}
private:
kernels::DequantizeFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX);
OP_OUTPUT_TAGS(OUTPUT);
};
template<DeviceType D, class T>
class RequantizeOp : public Operator<D, T> {
public:
RequantizeOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {
}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *in_min = this->Input(IN_MIN);
const Tensor *in_max = this->Input(IN_MAX);
const Tensor *rerange_min = nullptr;
const Tensor *rerange_max = nullptr;
MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
if (this->InputSize() >= 5) {
rerange_min = this->Input(RERANGE_MIN);
rerange_max = this->Input(RERANGE_MAX);
MACE_CHECK(rerange_min->size() == 1,
"rerange min val tensor has more than 1 value");
MACE_CHECK(rerange_max->size() == 1,
"rerange max val tensor has more than 1 value");
}
Tensor *output = this->Output(OUTPUT);
Tensor *out_min = this->Output(OUT_MIN);
Tensor *out_max = this->Output(OUT_MAX);
output->ResizeLike(input);
out_min->ResizeLike(in_min);
out_max->ResizeLike(in_max);
functor_(input,
in_min,
in_max,
rerange_min,
rerange_max,
output,
out_min,
out_max,
future);
return true;
}
private:
kernels::RequantizeFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX, RERANGE_MIN, RERANGE_MAX);
OP_OUTPUT_TAGS(OUTPUT, OUT_MIN, OUT_MAX);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_QUANTIZE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class QuantizeTest : public OpsTestBase {};
TEST_F(QuantizeTest, TestQuantize) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<CPU, float>("Input", {1, 2, 3, 1}, {
-2, -1, 1, 2, 3, 4
});
net.AddInputFromArray<CPU, float>("InputMin", {1}, {-3});
net.AddInputFromArray<CPU, float>("InputMax", {1}, {5});
OpDefBuilder("Quantize", "QuantizeTest")
.Input("Input")
.Input("InputMin")
.Input("InputMax")
.Output("Output")
.Output("OutputMin")
.Output("OutputMax")
.OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
auto output = net.GetTensor("Output");
auto output_min = net.GetTensor("OutputMin");
auto output_max = net.GetTensor("OutputMax");
auto expected_output = CreateTensor<uint8_t>({1, 2, 3, 1},
{
32, 64, 127, 159, 191, 223
});
auto expected_min = CreateTensor<float>({1}, {-3.01887});
auto expected_max = CreateTensor<float>({1}, {5});
ExpectTensorNear<uint8_t>(*expected_output, *output);
ExpectTensorNear<float>(*expected_min, *output_min);
ExpectTensorNear<float>(*expected_max, *output_max);
}
TEST_F(QuantizeTest, TestQuantizeTrend) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<CPU, float>("Input", {100});
const float *input_data = net.GetTensor("Input")->data<float>();
net.AddInputFromArray<CPU, float>("InputMin",
{1},
{*std::min_element(input_data,
input_data
+ net.GetTensor("Input")->size())});
net.AddInputFromArray<CPU, float>("InputMax",
{1},
{*std::max_element(input_data,
input_data
+ net.GetTensor("Input")->size())});
OpDefBuilder("Quantize", "QuantizeTest")
.Input("Input")
.Input("InputMin")
.Input("InputMax")
.Output("Output")
.Output("OutputMin")
.Output("OutputMax")
.OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
auto output = net.GetTensor("Output");
auto output_min = net.GetTensor("OutputMin");
auto output_max = net.GetTensor("OutputMax");
const uint8_t *output_data = net.GetTensor("Output")->data<uint8_t>();
for (int i = 1; i < output->size(); ++i) {
if (input_data[i] > input_data[i - 1]) {
EXPECT_GE(output_data[i], output_data[i - 1]);
} else if (input_data[i] == input_data[i - 1]) {
EXPECT_EQ(output_data[i], output_data[i - 1]);
} else {
EXPECT_LE(output_data[i], output_data[i - 1]);
}
}
}
TEST_F(QuantizeTest, TestDequantize) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<CPU, uint8_t>("Input", {1, 2, 3, 1}, {
32, 64, 127, 159, 191, 223
});
net.AddInputFromArray<CPU, float>("InputMin", {1}, {-3.01887});
net.AddInputFromArray<CPU, float>("InputMax", {1}, {5});
OpDefBuilder("Dequantize", "DequantizeTest")
.Input("Input")
.Input("InputMin")
.Input("InputMax")
.Output("Output")
.OutputType({DT_FLOAT})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
auto output = net.GetTensor("Output");
auto expected_output = CreateTensor<float>({1, 2, 3, 1},
{
-2, -1, 1, 2, 3, 4
});
auto expected_min = CreateTensor<float>({1}, {-3.01887});
auto expected_max = CreateTensor<float>({1}, {5});
ExpectTensorNear<float>(*expected_output, *output, 0.1, 0.01);
}
TEST_F(QuantizeTest, TestRequantizeWithMinMax) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<CPU, int>("Input", {1, 2, 3, 1}, {
-1073741824, -536870912, 536870912, 1073741824, 1610612736, 2147483647
});
net.AddInputFromArray<CPU, float>("InputMin", {1}, {-3});
net.AddInputFromArray<CPU, float>("InputMax", {1}, {5});
net.AddInputFromArray<CPU, float>("RerangeMin", {1}, {-3.01887});
net.AddInputFromArray<CPU, float>("RerangeMax", {1}, {5});
OpDefBuilder("Requantize", "RequantizeTest")
.Input("Input")
.Input("InputMin")
.Input("InputMax")
.Input("RerangeMin")
.Input("RerangeMax")
.Output("Output")
.Output("OutputMin")
.Output("OutputMax")
.OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
auto output = net.GetTensor("Output");
auto expected_output = CreateTensor<uint8_t>({1, 2, 3, 1},
{
32, 64, 128, 160, 191, 223
});
auto expected_min = CreateTensor<float>({1}, {-3.01887});
auto expected_max = CreateTensor<float>({1}, {5});
ExpectTensorNear<uint8_t>(*expected_output, *output);
}
TEST_F(QuantizeTest, TestRequantizeWithoutMinMax) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<CPU, int>("Input", {1, 2, 3, 1}, {
-1073741824, -536870912, 536870912, 1073741824, 1610612736, 2147483647
});
net.AddInputFromArray<CPU, float>("InputMin", {1}, {-3});
net.AddInputFromArray<CPU, float>("InputMax", {1}, {5});
OpDefBuilder("Requantize", "RequantizeTest")
.Input("Input")
.Input("InputMin")
.Input("InputMax")
.Output("Output")
.Output("OutputMin")
.Output("OutputMax")
.OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
auto output = net.GetTensor("Output");
auto expected_output = CreateTensor<uint8_t>({1, 2, 3, 1},
{
0, 43, 128, 170, 213, 255
});
auto expected_min = CreateTensor<float>({1}, {-3.01887});
auto expected_max = CreateTensor<float>({1}, {5});
ExpectTensorNear<uint8_t>(*expected_output, *output);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -28,16 +28,16 @@ class TransposeOp : public Operator<D, T> {
public:
TransposeOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
dims_(OperatorBase::GetRepeatedArgument<int>(
"dims")),
dims_(OperatorBase::GetRepeatedArgument<int>("dims")),
functor_(dims_) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
const std::vector<index_t> &input_shape = input->shape();
MACE_CHECK(input_shape.size() == 4 && dims_.size() == 4,
"rank should be 4");
MACE_CHECK(input_shape.size() == 4 && dims_.size() == 4
|| input_shape.size() == 2 && dims_.size() == 2,
"rank should be 2 or 4");
std::vector<index_t> output_shape;
for (int i = 0; i < dims_.size(); ++i) {
output_shape.push_back(input_shape[dims_[i]]);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template<DeviceType D, typename T>
void TransposeBenchmark(int iters,
std::vector<index_t> shape,
std::vector<int> dims) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", shape);
OpDefBuilder("Transpose", "TransposeBM")
.Input("Input")
.Output("Output")
.AddIntsArg("dims", dims)
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define BM_TRANSPOSE2D_MACRO(H, W, TYPE, DEVICE) \
static void BM_TRANSPOSE2D_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
TransposeBenchmark<DEVICE, TYPE>(iters, {H, W}, {1, 0}); \
} \
BENCHMARK(BM_TRANSPOSE2D_##H##_##W##_##TYPE##_##DEVICE)
#define BM_TRANSPOSE2D(H, W) \
BM_TRANSPOSE2D_MACRO(H, W, float, CPU);
#define BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, TYPE, DEVICE) \
static void \
BM_TRANSPOSE4D_##N##_##C##_##H##_##W##_##D0##D1##D2##D3##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
TransposeBenchmark<DEVICE, TYPE>(iters, {N, C, H, W}, {D0, D1, D2, D3}); \
} \
BENCHMARK( \
BM_TRANSPOSE4D_##N##_##C##_##H##_##W##_##D0##D1##D2##D3##_##TYPE##_##DEVICE)
#define BM_TRANSPOSE4D(N, C, H, W, D0, D1, D2, D3) \
BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, float, CPU);
BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2);
BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1);
BM_TRANSPOSE2D(128, 128);
BM_TRANSPOSE2D(512, 512);
} // namespace test
} // namespace ops
} // namespace mace
......@@ -49,6 +49,29 @@ TEST_F(TransposeOpTest, NCHW) {
TransposeNCHWTest({1, 64, 48, 128});
}
TEST_F(TransposeOpTest, Rank2) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<CPU, float>("Input", {2, 3}, {1, 2, 3, 4, 5, 6});
OpDefBuilder("Transpose", "TransposeNCHWTest")
.Input("Input")
.Output("Output")
.AddIntsArg("dims", {1, 0})
.Finalize(net.NewOperatorDef());
// Run on cpu
net.RunOp();
net.AddInputFromArray<CPU, float>("ExpectedOutput",
{3, 2},
{1, 4, 2, 5, 3, 6});
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -320,6 +320,13 @@ class CaffeConverter(object):
arg.name = 'T'
arg.i = self.dt
input_op = self.ops_map[name]
if input_op.layer is not None:
output_shape = input_op.output_shape_map[input_op.layer.top[0]]
else:
output_shape = input_op.output_shape_map[input_op.name]
self.add_output_shape(op_def, output_shape)
def add_output_transform(self, names):
for name in names:
output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0"
......@@ -1091,15 +1098,15 @@ class CaffeConverter(object):
dims_arg.ints.extend([0, 2, 3, 1]) # NCHW -> NHWC
def convert(self, input_nodes, input_shapes, output_nodes):
assert self.ops[0].type == 'Input'
self.add_input_op_shape(input_nodes, input_shapes)
if self.device == 'gpu':
self.add_input_transform(input_nodes)
if self.device == 'neon':
self.add_neon_input_transform(input_nodes)
assert self.ops[0].type == 'Input'
self.add_input_op_shape(input_nodes, input_shapes)
for op in self.ops:
if op.name in self.resolved_ops:
continue
......
......@@ -46,7 +46,11 @@ class MemoryOptimizer(object):
self.ref_counter[tensor_name] = 0
def is_buffer_image_op(self, op):
return op.type == 'BufferToImage' or op.type == 'ImageToBuffer'
if op.type == 'BufferToImage':
for arg in op.arg:
if arg.name == 'mode' and arg.i == 0:
return True
return op.type == 'ImageToBuffer'
def get_mem_size(self, op_type, output_shape):
mem_size = [0, 0]
......
......@@ -155,6 +155,8 @@ class TFConverter(object):
arg.name = 'T'
arg.i = self.dt
self.add_output_shape(self.ops[name].outputs, op_def)
def add_neon_input_transform(self, names):
for name in names:
new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0"
......
# Description:
# Mace operators.
#
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7", "if_hexagon_enabled")
cc_test(
name = "mace_api_test",
testonly = 1,
srcs = ["mace_api_test.cc"],
copts = if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
"//mace/ops:test",
"//mace/kernels:kernels",
"//mace/ops:ops",
"@gtest//:gtest_main",
],
)
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <fstream>
#include "mace/core/operator.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace test {
class MaceAPITest : public ::testing::Test {};
namespace {
void GenerateInputs(const std::vector<std::string> &input_names,
const std::vector<int64_t> &input_shape,
std::map<std::string, mace::MaceTensor> *inputs) {
size_t input_size = input_names.size();
for (size_t i = 0; i < input_size; ++i) {
// Allocate input and output
int64_t input_size =
std::accumulate(input_shape.begin(), input_shape.end(), 1,
std::multiplies<int64_t>());
auto buffer_in = std::shared_ptr<float>(new float[input_size],
std::default_delete<float[]>());
// load input
std::vector<float> input_data;
ops::test::GenerateRandomRealTypeData(input_shape, &input_data);
memcpy(buffer_in.get(), input_data.data(), input_size * sizeof(float));
(*inputs)[input_names[i]] = mace::MaceTensor(input_shape, buffer_in);
}
}
void GenerateOutputs(const std::vector<std::string> &output_names,
const std::vector<int64_t> &output_shape,
std::map<std::string, mace::MaceTensor> *outputs) {
size_t output_size = output_names.size();
for (size_t i = 0; i < output_size; ++i) {
int64_t output_size =
std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int64_t>());
auto buffer_out = std::shared_ptr<float>(new float[output_size],
std::default_delete<float[]>());
(*outputs)[output_names[i]] = mace::MaceTensor(output_shape, buffer_out);
}
}
template <typename T>
void BufferToImage(const std::string &input_name,
const std::string &output_name,
const int buffer_type,
const std::vector<int> &mem_ids,
NetDef *net_def,
const int mode = NetMode::NORMAL) {
OperatorDef operator_def;
ops::test::OpDefBuilder("BufferToImage", "BufferToImageOp")
.Input(input_name)
.Output(output_name)
.AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("mode", mode)
.Finalize(&operator_def);
operator_def.set_mem_id(mem_ids);
net_def->add_op()->CopyFrom(operator_def);
}
template <typename T>
void ImageToBuffer(const std::string &input_name,
const std::string &output_name,
const int buffer_type,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("ImageToBuffer", "ImageToBufferOp")
.Input(input_name)
.Output(output_name)
.AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def);
}
template <typename T>
void Conv3x3(const std::string &input_name,
const std::string &filter_name,
const std::string &output_name,
const std::vector<int> &mem_ids,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("Conv2D", "Conv2dOp")
.Input(input_name)
.Input(filter_name)
.Output(output_name)
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(&operator_def);
operator_def.set_mem_id(mem_ids);
net_def->add_op()->CopyFrom(operator_def);
}
template <typename T>
void Relu(const std::string &input_name,
const std::string &output_name,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("Activation", "ReluTest")
.Input(input_name)
.Output(output_name)
.AddStringArg("activation", "RELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def);
}
template <typename T>
void AddTensor(const std::string &name,
const std::vector<int64_t> &shape,
T *data,
NetDef *net_def) {
ConstTensor tensor(name,
reinterpret_cast<unsigned char *>(data),
shape,
DataTypeToEnum<T>::value);
net_def->mutable_tensors().push_back(tensor);
}
template <DeviceType D, typename T>
void CheckOutputs(const NetDef &net_def,
const std::map<std::string, mace::MaceTensor> &inputs,
const std::map<std::string, mace::MaceTensor> &outputs) {
ops::test::OpsTestNet net;
for (auto input : inputs) {
auto input_shape = input.second.shape();
const int64_t data_size = std::accumulate(input_shape.begin(),
input_shape.end(), 1,
std::multiplies<int64_t>());
std::vector<float> input_data(data_size);
memcpy(input_data.data(), input.second.data().get(),
data_size * sizeof(float));
std::string input_name = MakeString("mace_input_node_",
input.first, ":0");
net.AddInputFromArray<D, float>(input_name, input.second.shape(),
input_data);
}
auto tensors = net_def.tensors();
for (auto tensor : tensors) {
auto shape = tensor.dims();
const int64_t data_size = std::accumulate(shape.begin(),
shape.end(), 1,
std::multiplies<int64_t>());
std::vector<T> data(data_size);
memcpy(data.data(), reinterpret_cast<const T *>(tensor.data()),
data_size * sizeof(T));
net.AddInputFromArray<D, T>(tensor.name(), shape, data);
}
net.RunNet(net_def, D);
for (auto output : outputs) {
std::unique_ptr<Tensor> tmp_tensor(
new Tensor(GetDeviceAllocator(DeviceType::CPU),
DataTypeToEnum<float>::v()));
auto output_shape = output.second.shape();
const int64_t data_size = std::accumulate(output_shape.begin(),
output_shape.end(), 1,
std::multiplies<int64_t>());
tmp_tensor->Resize(output.second.shape());
float *data = tmp_tensor->mutable_data<float>();
memcpy(data, output.second.data().get(), data_size * sizeof(float));
std::string output_name = MakeString("mace_output_node_",
output.first, ":0");
ops::test::ExpectTensorNear<float>(*tmp_tensor,
*net.GetOutput(output_name.data()),
1e-5);
}
}
std::map<std::string, int> AddMemoryOptimization(
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes,
NetDef *net_def) {
std::map<std::string, int> res;
int mem_id = 0;
size_t input_shape_size = input_shapes.size();
uint32_t in_mem_block_x = 0;
uint32_t in_mem_block_y = 0;
for (size_t i = 0; i < input_shape_size; ++i) {
in_mem_block_x = std::max<uint32_t>(in_mem_block_x,
input_shapes[i][2] *
RoundUpDiv4(input_shapes[i][3]));
in_mem_block_y = std::max<uint32_t>(in_mem_block_y,
input_shapes[i][0] *
input_shapes[i][1]);
}
size_t input_size = input_names.size();
for (size_t i = 0; i < input_size; ++i) {
net_def->mutable_mem_arena().mutable_mem_block().push_back(
MemoryBlock(mem_id, in_mem_block_x, in_mem_block_y));
res[input_names[i]] = mem_id;
mem_id++;
}
size_t output_shape_size = output_shapes.size();
uint32_t out_mem_block_x = 0;
uint32_t out_mem_block_y = 0;
for (size_t i = 0; i < output_shape_size; ++i) {
out_mem_block_x = std::max<uint32_t>(out_mem_block_x,
output_shapes[i][2] *
RoundUpDiv4(output_shapes[i][3]));
out_mem_block_y = std::max<uint32_t>(out_mem_block_y,
output_shapes[i][0] *
output_shapes[i][1]);
}
size_t output_size = output_names.size();
for (size_t i = 0; i < output_size; ++i) {
net_def->mutable_mem_arena().mutable_mem_block().push_back(
MemoryBlock(mem_id, out_mem_block_x, out_mem_block_y));
res[output_names[i]] = mem_id;
mem_id++;
}
return res;
}
// The height and width of input and output must be equal.
template <typename T>
void MaceRun(const int in_out_size,
const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes,
const std::vector<int64_t> &filter_shape) {
std::vector<std::string> input_names;
std::vector<std::string> output_names;
for (int i = 0; i < in_out_size; ++i) {
input_names.push_back(MakeString("input", i));
output_names.push_back(MakeString("output", i));
}
std::string filter_tensor_name = "filter";
std::string filter_tensor_img_name = filter_tensor_name + "_image";
const DeviceType device = DeviceType::OPENCL;
NetDef net_def;
// Add memory optimization
auto mem_map = AddMemoryOptimization(input_names, output_names,
input_shapes, output_shapes,
&net_def);
std::vector<T> data;
ops::test::GenerateRandomRealTypeData<T>(filter_shape, &data);
AddTensor<T>(filter_tensor_name, filter_shape, data.data(), &net_def);
for (size_t i = 0; i < input_names.size(); ++i) {
std::string input_name = MakeString("mace_input_node_",
input_names[i], ":0");
BufferToImage<half>(input_name, input_names[i],
mace::kernels::IN_OUT_CHANNEL,
{mem_map[input_names[i]]},
&net_def);
}
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {},
&net_def, NetMode::INIT);
for (size_t i = 0; i < output_names.size(); ++i) {
Conv3x3<half>(input_names[i], filter_tensor_img_name,
output_names[i], {mem_map[output_names[i]]},
&net_def);
}
for (size_t i = 0; i < output_names.size(); ++i) {
std::string output_name = MakeString("mace_output_node_",
output_names[i], ":0");
ImageToBuffer<float>(output_names[i], output_name,
mace::kernels::IN_OUT_CHANNEL, &net_def);
}
MaceEngine engine(&net_def, device, input_names, output_names);
std::map<std::string, mace::MaceTensor> inputs;
std::map<std::string, mace::MaceTensor> outputs;
for (int i = 0; i < 5; ++i) {
size_t input_shape_size = input_shapes.size();
for (size_t j = 0; j < input_shape_size; ++j) {
inputs.clear();
outputs.clear();
GenerateInputs(input_names, input_shapes[j], &inputs);
GenerateOutputs(output_names, output_shapes[j], &outputs);
engine.Run(inputs, &outputs);
}
}
CheckOutputs<DeviceType::OPENCL, T>(net_def, inputs, outputs);
}
} // namespace
TEST_F(MaceAPITest, GPUSingleInputOutput) {
MaceRun<float>(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {3, 3, 16, 16});
MaceRun<half>(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {3, 3, 16, 16});
}
TEST_F(MaceAPITest, GPUMultipleInputOutput) {
MaceRun<float>(2,
{{1, 16, 32, 16}},
{{1, 16, 32, 16}},
{3, 3, 16, 16});
MaceRun<half>(2,
{{1, 16, 32, 16}},
{{1, 16, 32, 16}},
{3, 3, 16, 16});
}
TEST_F(MaceAPITest, GPUVariableInputShape) {
MaceRun<float>(1,
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{3, 3, 16, 16});
MaceRun<float>(2,
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{3, 3, 16, 16});
}
} // namespace test
} // namespace mace
......@@ -94,8 +94,8 @@ class Tuner {
Tuner &operator=(const Tuner &) = delete;
inline void WriteRunParameters() {
VLOG(3) << "Write tuning result to " << path_;
if (path_ != nullptr) {
VLOG(3) << "Write tuning result to " << path_;
std::ofstream ofs(path_, std::ios::binary | std::ios::out);
if (ofs.is_open()) {
int64_t num_pramas = param_table_.size();
......
......@@ -42,7 +42,7 @@ def load_data(file):
return np.empty([0])
def format_output_name(name):
def format_name(name):
return re.sub('[^0-9a-zA-Z]+', '_', name)
......@@ -87,7 +87,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
input_dict = {}
for i in range(len(input_names)):
input_value = load_data(
input_file + "_" + input_names[i])
input_file + "_" + format_name(input_names[i]))
input_value = input_value.reshape(input_shapes[i])
input_node = graph.get_tensor_by_name(
input_names[i] + ':0')
......@@ -100,7 +100,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
output_values = session.run(output_nodes, feed_dict=input_dict)
for i in range(len(output_names)):
output_file_name = mace_out_file + "_" + \
format_output_name(output_names[i])
format_name(output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i],
mace_out_value, output_values[i])
......@@ -123,7 +123,7 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
net = caffe.Net(model_file, caffe.TEST, weights=weight_file)
for i in range(len(input_names)):
input_value = load_data(input_file + "_" + input_names[i])
input_value = load_data(input_file + "_" + format_name(input_names[i]))
input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1,
2))
input_blob_name = input_names[i]
......@@ -142,7 +142,7 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[
1], out_shape[2]
value = value.reshape(out_shape).transpose((0, 2, 3, 1))
output_file_name = mace_out_file + "_" + format_output_name(
output_file_name = mace_out_file + "_" + format_name(
output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i], mace_out_value,
......