提交 11f20df4 编写于 作者: 李滨

Merge branch 'add_op' into 'master'

feature: support RELU6/ArgMax/ResizeNearestNeighbor GroupNorm op for Caffe

See merge request deep-computing/mace!1276
......@@ -38,6 +38,10 @@ ProtoArgHelper::ProtoArgHelper(const NetDef &netdef) {
}
}
bool ProtoArgHelper::ExistArg(const std::string &arg_name) const {
return (arg_map_.count(arg_name) > 0);
}
namespace {
template <typename InputType, typename TargetType>
inline bool IsCastLossless(const InputType &value) {
......
......@@ -41,6 +41,11 @@ class ProtoArgHelper {
return ProtoArgHelper(def).GetRepeatedArgs<T>(arg_name, default_value);
}
template <typename Def>
static bool ExistArg(const Def &def, const std::string &arg_name) {
return ProtoArgHelper(def).ExistArg(arg_name);
}
explicit ProtoArgHelper(const OperatorDef &def);
explicit ProtoArgHelper(const NetDef &netdef);
......@@ -55,6 +60,8 @@ class ProtoArgHelper {
template <typename T>
std::vector<T> GetRepeatedArgs(const std::string &arg_name) const;
bool ExistArg(const std::string &arg_name) const;
private:
std::map<std::string, Argument> arg_map_;
};
......
......@@ -64,6 +64,11 @@ class Operation {
*operator_def_, name);
}
bool ExistArg(const std::string &name) const {
MACE_CHECK(operator_def_, "operator_def is null!");
return ProtoArgHelper::ExistArg<OperatorDef>(*operator_def_, name);
}
DeviceType device_type() const {
return static_cast<DeviceType>(operator_def_->device_type());
}
......
......@@ -893,7 +893,7 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
MACE_CHECK(output_size <= output->second.impl_->buffer_size)
<< "Output size exceeds buffer size: shape"
<< output_tensor->name() << " Output size exceeds buffer size: shape"
<< MakeString<int64_t>(shape) << " vs buffer size "
<< output->second.impl_->buffer_size;
output->second.impl_->shape = shape;
......
......@@ -24,24 +24,101 @@
namespace mace {
namespace ops {
template <DeviceType D, class T>
template<DeviceType D, class T>
class ArgMaxOp : public Operation {
public:
explicit ArgMaxOp(OpConstructContext *context)
: Operation(context),
model_type_(static_cast<FrameworkType>(Operation::GetOptionalArg<int>(
"framework_type", FrameworkType::TENSORFLOW))),
has_axis_(model_type_ != FrameworkType::CAFFE ||
Operation::ExistArg("axis")),
top_k_(Operation::GetOptionalArg<int>("top_k", 1)),
out_val_(Operation::GetOptionalArg<bool>("out_val", false)),
axis_(Operation::GetOptionalArg<int>("axis", 0)),
keep_dims_(Operation::GetOptionalArg<bool>("keepdims", true)),
argmin_(Operation::GetOptionalArg<bool>("argmin", false)) {}
argmin_(Operation::GetOptionalArg<bool>("argmin", false)),
keep_dims_(Operation::GetOptionalArg<bool>("keepdims", true)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
const Tensor *axis = this->InputSize() == 2 ?
this->Input(1) : nullptr;
Tensor *output = this->Output(0);
MACE_CHECK(keep_dims_, "Mace only supports keep_dims ArgMax.");
MACE_CHECK(input->dim_size() > 0, "ArgMax input should not be a scalar");
const auto input_dim_size = input->dim_size();
MACE_CHECK(input_dim_size > 0, "ArgMax input should not be a scalar");
const auto axis_value = GetAxisValue(input_dim_size);
MACE_RETURN_IF_ERROR(ResizeOutputTensor(output, input, axis_value));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<T>();
int axis_dim = 0;
int axis_dist = 0;
const auto &input_shape = input->shape();
if (axis_value != 0) {
axis_dim = input->dim(axis_value);
axis_dist = std::accumulate(input_shape.begin() + axis_value,
input_shape.end(),
1, std::multiplies<int>()) / axis_dim;
} else {
axis_dim = input->dim(0);
axis_dist = 1;
}
const auto output_loop = input->size() / axis_dim;
for (int i = 0; i < output_loop; i += 1) {
std::vector<std::pair<T, int>> input_data_vector(axis_dim);
const auto axis_base = i / axis_dist * axis_dim;
const auto axis_offset = i % axis_dist;
for (int d = 0; d < axis_dim; ++d) {
const auto input_idx = (axis_base + d) * axis_dist + axis_offset;
input_data_vector[d] = std::make_pair(input_data[input_idx], d);
}
if (argmin_) {
std::partial_sort(input_data_vector.begin(),
input_data_vector.begin() + top_k_,
input_data_vector.end(),
std::less<std::pair<T, int>>());
} else {
std::partial_sort(input_data_vector.begin(),
input_data_vector.begin() + top_k_,
input_data_vector.end(),
std::greater<std::pair<T, int>>());
}
if (!out_val_) {
auto output_data = output->mutable_data<int32_t>();
const auto top_k_base = i / axis_dist * top_k_;
for (int j = 0; j < top_k_; ++j) {
const auto output_idx = (top_k_base + j) * axis_dist + axis_offset;
output_data[output_idx] = input_data_vector[j].second;
}
} else if (has_axis_) { // Produces max/min value per axis
auto output_data = output->mutable_data<T>();
const auto top_k_base = i / axis_dist * top_k_;
for (int j = 0; j < top_k_; ++j) {
auto output_idx = (top_k_base + j) * axis_dist + axis_offset;
output_data[output_idx] = input_data_vector[j].first;
}
} else { // Produces max_ind and max/min value
auto output_data = output->mutable_data<T>();
const auto top_k_base_pos = 2 * i * top_k_;
const auto top_k_base_value = top_k_base_pos + top_k_;
for (int j = 0; j < top_k_; ++j) {
output_data[top_k_base_pos + j] = input_data_vector[j].second;
output_data[top_k_base_value + j] = input_data_vector[j].first;
}
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
int GetAxisValue(const index_t input_dim_size) {
const Tensor *axis = this->InputSize() == 2 ? this->Input(1) : nullptr;
int axis_value = 0;
if (axis != nullptr) {
MACE_CHECK(axis->dim_size() == 0,
......@@ -52,65 +129,63 @@ class ArgMaxOp : public Operation {
axis_value = axis_;
}
if (axis_value < 0) {
axis_value += input->dim_size();
axis_value += input_dim_size;
}
MACE_CHECK(axis_value == input->dim_size() - 1,
"Mace argmax only supports last dimension as axis");
std::vector<index_t> output_shape(input->dim_size() - 1);
for (index_t d = 0; d < input->dim_size() - 1; ++d) {
output_shape[d] = input->dim(d < axis_value ? d : d + 1);
}
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
return axis_value;
}
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<T>();
auto output_data = output->mutable_data<int32_t>();
index_t outer_size = output->size();
index_t inner_size = input->dim(axis_value);
if (argmin_) {
for (index_t i = 0; i < outer_size; ++i) {
int idx = 0;
float min_value = std::numeric_limits<float>::max();
const T *input_ptr = input_data + i * inner_size;
for (index_t j = 0; j < inner_size; ++j) {
float input_value = input_ptr[j];
if (input_value < min_value) {
min_value = input_value;
idx = j;
}
}
output_data[i] = idx;
MaceStatus ResizeOutputTensor(Tensor *output, const Tensor *input,
const index_t axis_value) {
auto &input_shape = input->shape();
std::vector<index_t> output_shape;
if (model_type_ == FrameworkType::CAFFE) {
auto output_dim_num = input_shape.size();
if (output_dim_num < 3) {
output_dim_num = 3;
}
} else {
for (index_t i = 0; i < outer_size; ++i) {
int idx = 0;
float max_value = std::numeric_limits<float>::lowest();
const T *input_ptr = input_data + i * inner_size;
for (index_t j = 0; j < inner_size; ++j) {
float input_value = input_ptr[j];
if (input_value > max_value) {
max_value = input_value;
idx = j;
}
output_shape.assign(output_dim_num, 1);
if (has_axis_) {
// Produces max/min idx or max/min value per axis
output_shape.assign(input_shape.begin(), input_shape.end());
output_shape[axis_value] = top_k_;
} else {
output_shape[0] = input_shape[0];
// Produces max_ind
output_shape[2] = top_k_;
if (out_val_) {
// Produces max/min idx and max/min value
output_shape[1] = 2;
}
output_data[i] = idx;
}
} else { // for Tensorflow and ONNX
output_shape.assign(input_shape.begin(),
input_shape.begin() + axis_value);
if (keep_dims_) {
output_shape.push_back(1);
}
for (size_t d = axis_value + 1; d < input_shape.size(); ++d) {
output_shape.push_back(input_shape[d]);
}
}
return MaceStatus::MACE_SUCCESS;
return output->Resize(output_shape);
}
protected:
const int axis_;
bool keep_dims_;
bool argmin_;
};
const FrameworkType model_type_;
// for Caffe
const bool has_axis_;
const int top_k_;
const bool out_val_;
// for ONNX and TENSORFLOW
const int axis_;
const bool argmin_;
// for ONNX
const bool keep_dims_;
};
void RegisterArgMax(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ArgMax", ArgMaxOp, DeviceType::CPU, float);
......
......@@ -953,6 +953,17 @@ class EltwiseOp : public Operation {
swapped = !swapped;
}
// convert tensor for caffe's broadcast
if (!has_data_format_ && input0->dim_size() == 4) {
if (input1->dim_size() == 2) {
const_cast<Tensor *>(input1)->Reshape(
{input1->dim(0), input1->dim(1), 1, 1});
} else if (input1->dim_size() == 3) {
const_cast<Tensor *>(input1)->Reshape(
{input1->dim(0), input1->dim(1), input1->dim(2), 1});
}
}
// check if we can broadcast tensor
uint32_t rank_diff =
static_cast<uint32_t>(input0->dim_size() - input1->dim_size());
......
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/activation.h"
#include "mace/ops/delegator/activation.h"
#include "mace/utils/memory.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/mvnorm.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace ops {
template<DeviceType D, class T>
class GroupNormOp;
template<class T>
class GroupNormOp<DeviceType::CPU, T> : public Operation {
public:
explicit GroupNormOp(OpConstructContext *context)
: Operation(context),
eps_(Operation::GetOptionalArg<float>("epsilon",
static_cast<float>(1e-5))),
group_num_(Operation::GetOptionalArg<int>("group_num", 32)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = Input(INPUT);
Tensor *output = Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size());
const std::vector<index_t> &input_shape = input->shape();
MACE_RETURN_IF_ERROR(output->Resize(input_shape));
const auto batch = input_shape[0];
const auto channel = input_shape[1];
const auto height = input_shape[2];
const auto width = input_shape[3];
MACE_CHECK(channel % group_num_ == 0,
"group_num_ invalid.", channel, group_num_);
const auto group_size = channel / group_num_;
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
const auto outer_loop = batch * group_num_;
const auto inner_loop = group_size * height * width;
utils::ThreadPool &thread_pool =
context->device()->cpu_runtime()->thread_pool();
auto *scratch_buffer = context->device()->scratch_buffer();
scratch_buffer->Rewind();
auto scratch_buffer_size = outer_loop * sizeof(float) * 2;
MACE_RETURN_IF_ERROR(scratch_buffer->GrowSize(scratch_buffer_size));
float *mean_ptr = scratch_buffer->mutable_data<float>();
float *variance_ptr = mean_ptr + outer_loop;
// compute EX
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
const auto offset = inner_loop * i;
mean_ptr[i] = std::accumulate(input_data + offset,
input_data + offset + inner_loop,
static_cast<T>(0.0f));
mean_ptr[i] /= inner_loop;
}
}, 0, outer_loop, 1);
// compute (X - EX)^2
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
const auto offset = i * inner_loop;
for (index_t j = start1; j < end1; j += step1) {
const auto idx = offset + j;
const auto x_ex = input_data[idx] - mean_ptr[i];
output_data[idx] = x_ex * x_ex;
}
}
}, 0, outer_loop, 1, 0, inner_loop, 1);
// compute (E((X - EX)^2) + eps_)^0.5
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
auto output_data_base = output_data + inner_loop * i;
variance_ptr[i] = std::accumulate(output_data_base,
output_data_base + inner_loop,
static_cast<T>(0.0f));
variance_ptr[i] = std::pow(variance_ptr[i] / inner_loop + eps_, 0.5f);
}
}, 0, outer_loop, 1);
// compute (X - EX) / ((E((X - EX)^2) + eps_)^0.5)
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
const auto offset = i * inner_loop;
for (index_t j = start1; j < end1; j += step1) {
output_data[offset + j] =
(input_data[offset + j] - mean_ptr[i]) / variance_ptr[i];
}
}
}, 0, outer_loop, 1, 0, inner_loop, 1);
return MaceStatus::MACE_SUCCESS;
}
private:
const float eps_;
const int group_num_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#ifdef MACE_ENABLE_OPENCL
template<>
class GroupNormOp<DeviceType::GPU, float> : public Operation {
public:
explicit GroupNormOp(OpConstructContext *context) : Operation(context) {
const auto group_num = Operation::GetOptionalArg<int>("group_num", 32);
const auto eps = Operation::GetOptionalArg<float>(
"epsilon", static_cast<float>(1e-5));
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::MVNormKernel>(
true, opencl::image::MeanType::GROUP_CHANNELS, eps, group_num);
} else {
MACE_NOT_IMPLEMENTED;
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.",
input->dim_size());
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
return kernel_->Compute(context, input, output);
}
private:
std::unique_ptr<OpenCLMVNormKernel> kernel_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#endif // MACE_ENABLE_OPENCL
void RegisterGroupNorm(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "GroupNorm", GroupNormOp,
DeviceType::CPU, float);
MACE_REGISTER_BF16_OP(op_registry, "GroupNorm", GroupNormOp, DeviceType::CPU);
MACE_REGISTER_GPU_OP(op_registry, "GroupNorm", GroupNormOp);
MACE_REGISTER_OP_CONDITION(
op_registry, OpConditionBuilder("GroupNorm").SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return {DeviceType::CPU, DeviceType::GPU};
}
const int group_num = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "group_num", 32);
auto output_channels = op->output_shape(0).dims()[3];
const int group_size = output_channels / group_num;
if (group_size % 4 == 0) {
return {DeviceType::CPU, DeviceType::GPU};
}
return {DeviceType::CPU};
}));
}
} // namespace ops
} // namespace mace
......@@ -145,9 +145,12 @@ class MVNormOp<DeviceType::GPU, float> : public Operation {
Operation::GetOptionalArg<bool>("across_channels", false);
auto eps = Operation::GetOptionalArg<float>("epsilon", 1e-9);
auto mean_type = across_channels ?
opencl::image::MeanType::ACROSS_CHANNELS :
opencl::image::MeanType::SINGLE_CHANNEL;
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::MVNormKernel>(
normalize_variance, across_channels, eps);
normalize_variance, mean_type, eps);
} else {
MACE_NOT_IMPLEMENTED;
}
......
#include <common.h>
DATA_TYPE4 compute_mean_image(image2d_t input, const int width_idx,
const int hb_idx, const int chan_blks,
const int height, const int width) {
DATA_TYPE4 compute_mean_image(image2d_t input, const int height,
const int width, const int chan_blks,
const int group_blks,
const int batch_idx, const int chan_blk_idx) {
DATA_TYPE4 total = 0.0f;
DATA_TYPE4 mean = 0.0f;
const int hb_base = mul24(hb_idx / height, height);
const int wc_blks = mul24(width, chan_blks);
const int hb_base = mul24(batch_idx, height);
#ifdef ACROSS_CHANNELS
const int wc_blks = mul24(width, chan_blks);
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int pos = 0; pos < wc_blks; ++pos) {
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
......@@ -16,7 +17,23 @@ DATA_TYPE4 compute_mean_image(image2d_t input, const int width_idx,
}
}
DATA_TYPE total_value = total.x + total.y + total.z + total.w;
DATA_TYPE mean_value = total_value / (DATA_TYPE)(mul24(mul24(height, wc_blks), 4));
DATA_TYPE mean_value =
total_value / (DATA_TYPE)(mul24(mul24(height, wc_blks), 4));
mean = (DATA_TYPE4){mean_value, mean_value, mean_value, mean_value};
#else
#ifdef GROUP_CHANNELS
const int group_base = chan_blk_idx / group_blks * group_blks;
const int wg_blks_start = mul24(width, group_base);
const int wg_blks_end = wg_blks_start + group_blks * width;
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int pos = wg_blks_start; pos < wg_blks_end; ++pos) {
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
total += in_data;
}
}
DATA_TYPE total_value = total.x + total.y + total.z + total.w;
const int total_num = mul24(mul24(height, wg_blks_end - wg_blks_start), 4);
DATA_TYPE mean_value = total_value / (DATA_TYPE)(total_num);
mean = (DATA_TYPE4){mean_value, mean_value, mean_value, mean_value};
#else
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
......@@ -27,15 +44,42 @@ DATA_TYPE4 compute_mean_image(image2d_t input, const int width_idx,
}
}
mean = total / mul24(height, width);
#endif
#endif // GROUP_CHANNELS
#endif // ACROSS_CHANNELS
return mean;
}
__kernel void mvnorm_compute_mean_value(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t input,
__private const int height,
__private const int width,
__private const int group_blks,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int batch_idx = get_global_id(1);
#ifndef NON_UNIFORM_WORK_GROUP
if (chan_blk_idx >= global_size_dim0 || batch_idx >= global_size_dim1) {
return;
}
#endif
const int chan_blks = global_size_dim0;
const int batch = global_size_dim1;
DATA_TYPE4 mean = compute_mean_image(input, height, width, chan_blks,
group_blks, batch_idx, chan_blk_idx);
WRITE_IMAGET(output, (int2)(chan_blk_idx, batch_idx), mean);
}
__kernel void mvnorm_mean(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__read_only image2d_t mean_image, // E(X)
__private const int height,
__private const int group_blks,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
......@@ -54,18 +98,20 @@ __kernel void mvnorm_mean(OUT_OF_RANGE_PARAMS
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
DATA_TYPE4 mean = compute_mean_image(input, width_idx,
hb_idx, chan_blks, height, width);
DATA_TYPE4 mean = READ_IMAGET(
mean_image, SAMPLER, (int2)(chan_blk_idx, hb_idx / height));
in_data -= mean;
WRITE_IMAGET(output, (int2)(pos, hb_idx), in_data);
}
// compute the (X - EX)^2
__kernel void mvnorm_vn_step1(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__write_only image2d_t mean_image, // E(X)
__read_only image2d_t mean_image, // E(X)
__write_only image2d_t square_image, // (X - EX)^2
__private const int height) {
__private const int height,
__private const int group_blks) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
......@@ -82,23 +128,21 @@ __kernel void mvnorm_vn_step1(OUT_OF_RANGE_PARAMS
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
DATA_TYPE4 mean = compute_mean_image(input, width_idx,
hb_idx, chan_blks, height, width);
DATA_TYPE4 mean =
READ_IMAGET(mean_image, SAMPLER, (int2)(chan_blk_idx, hb_idx / height));
in_data = in_data - mean;
DATA_TYPE4 pow_data = in_data * in_data;
if (hb_idx == 0 && width_idx == 0) {
WRITE_IMAGET(mean_image, (int2)(chan_blk_idx, 0), mean);
}
DATA_TYPE4 pow_data = in_data * in_data;
WRITE_IMAGET(square_image, (int2)(pos, hb_idx), pow_data);
}
// compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
__kernel void mvnorm_vn_step2(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__read_only image2d_t mean_image, // E(X)
__read_only image2d_t square_image, // (X - EX)^2
__read_only image2d_t mean_image_sqr, // E((X - EX)^2)
__private const int height,
__private const int group_blks,
__private const float eps,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
......@@ -115,15 +159,20 @@ __kernel void mvnorm_vn_step2(OUT_OF_RANGE_PARAMS
const int chan_blks = global_size_dim0;
const int width = global_size_dim1;
DATA_TYPE4 mean = READ_IMAGET(mean_image, SAMPLER, (int2)(chan_blk_idx, 0));
DATA_TYPE4 mean = READ_IMAGET(
mean_image, SAMPLER, (int2)(chan_blk_idx, hb_idx / height));
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
in_data = in_data - mean;
DATA_TYPE4 mean_v = compute_mean_image(square_image, width_idx,
hb_idx, chan_blks, height, width);
DATA_TYPE4 mean_sqr = READ_IMAGET(
mean_image_sqr, SAMPLER, (int2)(chan_blk_idx, hb_idx / height));;
DATA_TYPE4 norm_data = in_data / (sqrt(mean_v) + eps);
#ifdef GROUP_CHANNELS
DATA_TYPE4 norm_data = in_data / sqrt(mean_sqr + eps);
#else
DATA_TYPE4 norm_data = in_data / (sqrt(mean_sqr) + eps);
#endif
WRITE_IMAGET(output, (int2)(pos, hb_idx), norm_data);
}
......@@ -29,101 +29,167 @@ namespace {
MaceStatus BuildMVNKernel(OpenCLRuntime *runtime, cl::Kernel *kernel,
const char *kernel_name,
std::set<std::string> *built_options,
bool across_channel) {
std::stringstream micro_name;
micro_name << "-Dmvnorm=" << kernel_name;
built_options->emplace(micro_name.str());
MeanType mean_type_) {
std::stringstream macro_name;
macro_name << "-Dmvnorm=" << kernel_name;
built_options->emplace(macro_name.str());
built_options->emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
built_options->emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DT_FLOAT));
if (across_channel) {
if (mean_type_ == MeanType::ACROSS_CHANNELS) {
built_options->emplace("-DACROSS_CHANNELS");
} else if (mean_type_ == MeanType::GROUP_CHANNELS) {
built_options->emplace("-DGROUP_CHANNELS");
}
MACE_RETURN_IF_ERROR(runtime->BuildKernel("mvnorm", kernel_name,
*built_options, kernel));
return MaceStatus::MACE_SUCCESS;
}
std::unique_ptr<Image> CreateImage(
OpContext *context, const DataType dt,
const std::vector<index_t> &buffer_shape) {
std::unique_ptr<Image> image =
make_unique<Image>(context->device()->allocator());
std::vector<size_t> shape;
OpenCLUtil::CalImage2DShape(
buffer_shape, OpenCLBufferType::IN_OUT_CHANNEL, &shape);
MACE_CHECK(image->Allocate(shape, dt) == MaceStatus::MACE_SUCCESS);
VLOG(1) << "MVNormKernel::CreateImage allocate image_:" << MakeString(shape);
return image;
}
} // namespace
MVNormKernel::MVNormKernel(bool normalize_variance,
bool across_channels, float eps)
MeanType mean_type, float eps, int group_num)
: normalize_variance_(normalize_variance),
across_channels_(across_channels),
eps_(eps) {}
void MVNormKernel::CheckImage(OpContext *context, const DataType dt,
const std::vector<index_t> &square_shape,
const std::vector<index_t> &mean_shape) {
if (square_image_ == nullptr) {
square_image_ = CreateImage(context, dt, square_shape);
}
if (mean_image_ == nullptr) {
mean_image_ = CreateImage(context, dt, mean_shape);
}
}
mean_type_(mean_type),
eps_(eps),
group_num_(group_num) {}
MaceStatus MVNormKernel::Compute(OpContext
*context,
const Tensor *input, Tensor
*output) {
MaceStatus MVNormKernel::Compute(OpContext *context,
const Tensor *input, Tensor *output) {
const auto batch = input->dim(0);
const auto height = input->dim(1);
const auto width = input->dim(2);
const auto channels = input->dim(3);
index_t group_blocks = 0;
if (mean_type_ == MeanType::GROUP_CHANNELS) {
MACE_CHECK(group_num_ > 0, "group num should > 0");
const index_t group = channels / group_num_;
MACE_CHECK(group > 0 && group % 4 == 0, group, " can not be divided by 4");
group_blocks = group / 4;
}
return DoCompute(context, input, output, batch,
height, width, channels, group_blocks);
}
MaceStatus MVNormKernel::DoCompute(
OpContext *context, const Tensor *input, Tensor *output,
const index_t batch, const index_t height, const index_t width,
const index_t channels, const index_t group_blocks) {
const index_t channel_blocks = RoundUpDiv4(channels);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
const std::vector<index_t > mean_shape = {batch, 1, 1, channels};
std::vector<size_t> mean_image_shape;
OpenCLUtil::CalImage2DShape(mean_shape, OpenCLBufferType::IN_OUT_HEIGHT,
&mean_image_shape);
ScratchImageManager *scratch_manager =
context->device()->gpu_runtime()->scratch_image_manager();
ScratchImage scratch_mean_image(scratch_manager);
auto mace_mean_image = scratch_mean_image.Scratch(
context->device()->allocator(), mean_image_shape, input->dtype());
cl::Image *mean_image = static_cast<cl::Image *>(mace_mean_image->buffer());
if (normalize_variance_) {
const std::vector<index_t> &square_shape = input->buffer_shape();
const std::vector<index_t> mean_shape = {1, 1, 1, channels};
CheckImage(context, input->dtype(), square_shape, mean_shape);
// compute the (X - EX)^2
ScratchImage scratch_mean_image_sqr(scratch_manager);
auto mace_mean_image_sqr = scratch_mean_image_sqr.Scratch(
context->device()->allocator(), mean_image_shape, input->dtype());
cl::Image *mean_image_sqr =
static_cast<cl::Image *>(mace_mean_image_sqr->buffer());
// compute the EX
MACE_RETURN_IF_ERROR(ExecuteMeanValueKernel(
context, runtime, batch, height, width, channel_blocks, group_blocks,
input->opencl_image(), mean_image));
// compute (X - EX)^2 to output
MACE_RETURN_IF_ERROR(ExecuteVarianceNormStep1Kernel(
context, runtime, gws, input));
context, runtime, gws, height, group_blocks, input->opencl_image(),
mean_image, output->opencl_image()));
// compute E((X - EX)^2) to mean_image_sqr_
MACE_RETURN_IF_ERROR(ExecuteMeanValueKernel(
context, runtime, batch, height, width, channel_blocks, group_blocks,
output->opencl_image(), mean_image_sqr));
// compute the compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
MACE_RETURN_IF_ERROR(ExecuteVarianceNormStep2Kernel(
context, runtime, gws, input, output));
context, runtime, gws, height, group_blocks, input->opencl_image(),
mean_image, mean_image_sqr, output->opencl_image()));
} else {
// compute the EX
MACE_RETURN_IF_ERROR(ExecuteMeanValueKernel(
context, runtime, batch, height, width, channel_blocks, group_blocks,
input->opencl_image(), mean_image));
// compute the (X - EX)
MACE_RETURN_IF_ERROR(ExecuteMeanNormKernel(
context, runtime, gws, input, output));
context, runtime, gws, height, group_blocks, input->opencl_image(),
mean_image, output->opencl_image()));
}
return
MaceStatus::MACE_SUCCESS;
return MaceStatus::MACE_SUCCESS;
}
MaceStatus MVNormKernel::ExecuteMeanValueKernel(OpContext *context,
OpenCLRuntime *runtime,
const index_t batch,
const index_t height,
const index_t width,
const index_t channel_blocks,
const index_t group_blocks,
const cl::Image *input_image,
cl::Image *output_image) {
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_mean_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(
BuildMVNKernel(runtime, &kernel_mean_, "mvnorm_compute_mean_value",
&built_options, mean_type_));
kwg_size_mean_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_mean_));
}
const uint32_t gws[2] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(batch)};
const std::vector<uint32_t> lws = {static_cast<uint32_t>(kwg_size_mean_) / 8,
8, 0};
MACE_OUT_OF_RANGE_INIT(kernel_mean_);
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_mean_);
MACE_SET_2D_GWS_ARGS(kernel_mean_, gws);
kernel_mean_.setArg(idx++, *input_image);
kernel_mean_.setArg(idx++, static_cast<int>(height));
kernel_mean_.setArg(idx++, static_cast<int>(width));
kernel_mean_.setArg(idx++, static_cast<int>(group_blocks));
kernel_mean_.setArg(idx++, *output_image);
std::string tuning_key = Concat(
"mvnorm_compute_mean_opencl_kernel", gws[0],
gws[1], normalize_variance_, mean_type_);
MACE_RETURN_IF_ERROR(TuningOrRun2DKernel(runtime, kernel_mean_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
MaceStatus MVNormKernel::ExecuteMeanNormKernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input,
Tensor *output) {
const auto height = input->dim(1);
const index_t height,
const index_t group_blocks,
const cl::Image *input,
const cl::Image *mean_image,
cl::Image *output) {
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_step1_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step1_, "mvnorm_mean",
&built_options, across_channels_));
&built_options, mean_type_));
kwg_size_step1_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_step1_));
}
......@@ -132,14 +198,16 @@ MaceStatus MVNormKernel::ExecuteMeanNormKernel(OpContext *context,
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_step1_);
MACE_SET_3D_GWS_ARGS(kernel_step1_, gws);
kernel_step1_.setArg(idx++, *(input->opencl_image()));
kernel_step1_.setArg(idx++, *input);
kernel_step1_.setArg(idx++, *mean_image);
kernel_step1_.setArg(idx++, static_cast<int>(height));
kernel_step1_.setArg(idx++, *(output->opencl_image()));
kernel_step1_.setArg(idx++, static_cast<int>(group_blocks));
kernel_step1_.setArg(idx++, *output);
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step1_);
std::string
tuning_key = Concat("mvnorm_mean_opencl_kernel", gws[0], gws[1], gws[2],
normalize_variance_, across_channels_);
std::string tuning_key = Concat("mvnorm_mean_opencl_kernel", gws[0], gws[1],
gws[2], normalize_variance_,
mean_type_, group_blocks);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step1_, tuning_key,
gws, lws, context->future()));
......@@ -147,12 +215,11 @@ MaceStatus MVNormKernel::ExecuteMeanNormKernel(OpContext *context,
return MaceStatus::MACE_SUCCESS;
}
// The first step of compute Variance Norm, compute the (X - EX)^2
// store them into the square_image_
// compute the (X - EX)^2
MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
OpContext *context, OpenCLRuntime *runtime,
const uint32_t (&gws)[3], const Tensor *input) {
const auto height = input->dim(1);
const uint32_t (&gws)[3], const index_t height, const index_t group_blocks,
const cl::Image *input, const cl::Image *mean_image, cl::Image *output) {
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_step1_.get() == nullptr) {
std::set<std::string> built_options;
......@@ -160,7 +227,7 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step1_,
"mvnorm_vn_step1",
&built_options, across_channels_));
&built_options, mean_type_));
kwg_size_step1_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_step1_));
}
......@@ -169,18 +236,17 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_step1_);
MACE_SET_3D_GWS_ARGS(kernel_step1_, gws);
kernel_step1_.setArg(idx++, *(input->opencl_image()));
cl::Image *mean_image = static_cast<cl::Image *>(mean_image_->buffer());
kernel_step1_.setArg(idx++, *input);
kernel_step1_.setArg(idx++, *mean_image);
cl::Image *square_image = static_cast<cl::Image *>(square_image_->buffer());
kernel_step1_.setArg(idx++, *square_image);
kernel_step1_.setArg(idx++, *output);
kernel_step1_.setArg(idx++, static_cast<int>(height));
kernel_step1_.setArg(idx++, static_cast<int>(group_blocks));
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step1_);
std::string
tuning_key = Concat("mvnorm_v_step1_opencl_kernel", gws[0], gws[1],
gws[2], normalize_variance_,
across_channels_);
mean_type_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step1_, tuning_key,
gws, lws, context->future()));
......@@ -188,12 +254,12 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
return MaceStatus::MACE_SUCCESS;
}
// The second step of compute Variance Norm, read the (X - EX)^2 from
// square_image_ and compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
// compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
MaceStatus MVNormKernel::ExecuteVarianceNormStep2Kernel(
OpContext *context, OpenCLRuntime *runtime, const uint32_t (&gws)[3],
const Tensor *input, Tensor *output) {
const auto height = input->dim(1);
const index_t height, const index_t group_blocks,
const cl::Image *input, const cl::Image *mean_image,
const cl::Image *mean_image_sqr, cl::Image *output) {
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_step2_.get() == nullptr) {
std::set<std::string> built_options;
......@@ -201,7 +267,7 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep2Kernel(
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step2_,
"mvnorm_vn_step2",
&built_options, across_channels_));
&built_options, mean_type_));
kwg_size_step2_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_step2_));
}
......@@ -210,20 +276,19 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep2Kernel(
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_step2_);
MACE_SET_3D_GWS_ARGS(kernel_step2_, gws);
kernel_step2_.setArg(idx++, *(input->opencl_image()));
cl::Image *mean_image = static_cast<cl::Image *>(mean_image_->buffer());
kernel_step2_.setArg(idx++, *input);
kernel_step2_.setArg(idx++, *mean_image);
cl::Image *square_image = static_cast<cl::Image *>(square_image_->buffer());
kernel_step2_.setArg(idx++, *square_image);
kernel_step2_.setArg(idx++, *mean_image_sqr);
kernel_step2_.setArg(idx++, static_cast<int>(height));
kernel_step2_.setArg(idx++, static_cast<int>(group_blocks));
kernel_step2_.setArg(idx++, static_cast<float>(eps_));
kernel_step2_.setArg(idx++, *(output->opencl_image()));
kernel_step2_.setArg(idx++, *output);
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step2_);
std::string
tuning_key = Concat("mvnorm_v_step2_opencl_kernel", gws[0], gws[1],
gws[2], normalize_variance_,
across_channels_);
mean_type_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step2_, tuning_key,
gws, lws, context->future()));
......
......@@ -28,48 +28,79 @@ namespace ops {
namespace opencl {
namespace image {
enum MeanType {
SINGLE_CHANNEL,
GROUP_CHANNELS,
ACROSS_CHANNELS,
};
class MVNormKernel : public OpenCLMVNormKernel {
public:
explicit MVNormKernel(bool normalize_variance_,
bool across_channels, float eps);
explicit MVNormKernel(bool normalize_variance_, MeanType mean_type,
float eps, int group_num = 0);
~MVNormKernel() = default;
MaceStatus Compute(
OpContext *context, const Tensor *input, Tensor *output) override;
private:
void CheckImage(OpContext *context, const DataType dt,
const std::vector<index_t> &square_shape,
const std::vector<index_t> &mean_shape);
MaceStatus DoCompute(OpContext *context, const Tensor *input,
Tensor *output, const index_t batch,
const index_t height, const index_t width,
const index_t channels, const index_t group_blocks);
MaceStatus ExecuteMeanValueKernel(OpContext *context,
OpenCLRuntime *runtime,
const index_t batch,
const index_t height,
const index_t width,
const index_t channel_blocks,
const index_t group_blocks,
const cl::Image *input_image,
cl::Image *output_image);
MaceStatus ExecuteMeanNormKernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input,
Tensor *output);
const index_t height,
const index_t group_blocks,
const cl::Image *input,
const cl::Image *mean_image,
cl::Image *output);
// compute the (X - EX)^2
MaceStatus ExecuteVarianceNormStep1Kernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input);
const index_t height,
const index_t group_blocks,
const cl::Image *input,
const cl::Image *mean_image,
cl::Image *output);
// compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
MaceStatus ExecuteVarianceNormStep2Kernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input,
Tensor *output);
const index_t height,
const index_t group_blocks,
const cl::Image *input,
const cl::Image *mean_image,
const cl::Image *mean_image_sqr,
cl::Image *output);
private:
bool normalize_variance_;
bool across_channels_;
float eps_;
const bool normalize_variance_;
const MeanType mean_type_;
const float eps_;
const int group_num_;
cl::Kernel kernel_mean_;
uint32_t kwg_size_mean_;
cl::Kernel kernel_step1_;
uint32_t kwg_size_step1_;
cl::Kernel kernel_step2_;
uint32_t kwg_size_step2_;
// the cache of (X - EX)^2
std::unique_ptr<Image> square_image_;
// the cache of EX
std::unique_ptr<Image> mean_image_;
};
} // namespace image
......
......@@ -24,23 +24,13 @@ namespace image {
MaceStatus ResizeNearestNeighborKernel::Compute(
OpContext *context,
const Tensor *input,
const Tensor *size,
const std::vector<index_t> &dims,
const index_t out_height,
const index_t out_width,
Tensor *output) {
const index_t batch = input->dim(0);
const index_t in_height = input->dim(1);
const index_t in_width = input->dim(2);
const index_t channels = input->dim(3);
index_t out_height = 0;
index_t out_width = 0;
if (dims.size() < 2) {
Tensor::MappingGuard size_mapper(size);
out_height = size->data<int32_t>()[0];
out_width = size->data<int32_t>()[1];
} else {
out_height = dims[0];
out_width = dims[1];
}
const index_t channel_blocks = RoundUpDiv4(channels);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
......
......@@ -72,8 +72,8 @@ class ResizeNearestNeighborKernel : public OpenCLResizeNearestNeighborKernel {
MaceStatus Compute(
OpContext *context,
const Tensor *input,
const Tensor *size,
const std::vector<index_t> &dims,
const index_t out_height,
const index_t out_width,
Tensor *output) override;
private:
......
......@@ -16,7 +16,6 @@
#define MACE_OPS_OPENCL_MVNORM_H_
#include "mace/public/mace.h"
#include "mace/utils/math.h"
namespace mace {
......
......@@ -32,8 +32,8 @@ class OpenCLResizeNearestNeighborKernel {
virtual MaceStatus Compute(
OpContext *context,
const Tensor *input,
const Tensor *size,
const std::vector<index_t> &dims,
const index_t out_height,
const index_t out_width,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLResizeNearestNeighborKernel);
};
......
......@@ -41,6 +41,7 @@ extern void RegisterExtractPooling(OpRegistry *op_registry);
extern void RegisterFill(OpRegistry *op_registry);
extern void RegisterFullyConnected(OpRegistry *op_registry);
extern void RegisterGather(OpRegistry *op_registry);
extern void RegisterGroupNorm(OpRegistry *op_registry);
extern void RegisterIdentity(OpRegistry *op_registry);
extern void RegisterIfDefined(OpRegistry *op_registry);
extern void RegisterInferConv2dShape(OpRegistry *op_registry);
......@@ -120,6 +121,7 @@ void RegisterAllOps(OpRegistry *registry) {
ops::RegisterFill(registry);
ops::RegisterFullyConnected(registry);
ops::RegisterGather(registry);
ops::RegisterGroupNorm(registry);
ops::RegisterIdentity(registry);
ops::RegisterIfDefined(registry);
ops::RegisterInferConv2dShape(registry);
......
......@@ -78,27 +78,37 @@ class ResizeNearestNeighborOp<DeviceType::CPU, T> : public Operation {
public:
explicit ResizeNearestNeighborOp(OpConstructContext *context)
: Operation(context),
align_corners_(Operation::GetOptionalArg<bool>("align_corners",
false)) {}
align_corners_(Operation::GetOptionalArg<bool>("align_corners", false)),
height_scale_(Operation::GetOptionalArg<float>("height_scale", 0)),
width_scale_(Operation::GetOptionalArg<float>("width_scale", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
const Tensor *size = this->Input(1);
Tensor::MappingGuard size_mapper(size);
Tensor *output = this->Output(0);
MACE_CHECK(input->dim_size() == 4 && size->dim_size() == 1,
"input must be 4-dimensional and size must be 1-dimensional. ",
input->dim_size(), size->dim_size());
MACE_CHECK(input->dim_size() == 4,
"input must be 4-dimensional.", input->dim_size());
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t in_height = input->dim(2);
const index_t in_width = input->dim(3);
const index_t out_height = size->data<int32_t>()[0];
const index_t out_width = size->data<int32_t>()[1];
index_t out_height = 0;
index_t out_width = 0;
if (height_scale_ > 0) { // for Caffe
out_height = static_cast<index_t>(height_scale_ * in_height);
out_width = static_cast<index_t>(width_scale_ * in_width);
} else { // for tensor (Tf and ONNX)
const Tensor *size = this->Input(1);
Tensor::MappingGuard size_mapper(size);
MACE_CHECK(size->dim_size() == 1,
"size must be 1-dimensional.", size->dim_size());
out_height = size->data<int32_t>()[0];
out_width = size->data<int32_t>()[1];
}
MACE_CHECK(out_height > 0 && out_width > 0, out_height, out_width);
std::vector<index_t> out_shape{batch, channels, out_height, out_width};
MACE_RETURN_IF_ERROR(output->Resize(out_shape));
......@@ -114,14 +124,15 @@ class ResizeNearestNeighborOp<DeviceType::CPU, T> : public Operation {
return MaceStatus::MACE_SUCCESS;
}
float height_scale =
common::utils::CalculateResizeScale(in_height,
out_height,
align_corners_);
float width_scale =
common::utils::CalculateResizeScale(in_width,
out_width,
align_corners_);
// Caffe's scale is the opposite of ours
float height_scale = height_scale_ > 0 ? 1 / height_scale_ :
common::utils::CalculateResizeScale(in_height,
out_height,
align_corners_);
float width_scale = width_scale_ > 0 ? 1 / width_scale_ :
common::utils::CalculateResizeScale(in_width,
out_width,
align_corners_);
ResizeImageNCHW(context,
input_data,
batch,
......@@ -139,6 +150,8 @@ class ResizeNearestNeighborOp<DeviceType::CPU, T> : public Operation {
private:
bool align_corners_;
float height_scale_;
float width_scale_;
};
#ifdef MACE_ENABLE_OPENCL
......@@ -146,7 +159,9 @@ template<>
class ResizeNearestNeighborOp<DeviceType::GPU, float> : public Operation {
public:
explicit ResizeNearestNeighborOp(OpConstructContext *context)
: Operation(context), dim_(Operation::GetRepeatedArgs<index_t>("dim")) {
: Operation(context), dim_(Operation::GetRepeatedArgs<index_t>("dim")),
height_scale_(Operation::GetOptionalArg<float>("height_scale", 0)),
width_scale_(Operation::GetOptionalArg<float>("width_scale", 0)) {
bool align_corners = Operation::GetOptionalArg<bool>(
"align_corners", false);
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
......@@ -158,17 +173,34 @@ class ResizeNearestNeighborOp<DeviceType::GPU, float> : public Operation {
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(0);
const Tensor *size = this->Input(1);
Tensor *output = this->Output(0);
MACE_CHECK(input->dim_size() == 4 && size->dim_size() == 1,
"input must be 4-dimensional and size must be 1-dimensional.",
input->dim_size(), size->dim_size());
MACE_CHECK(input->dim_size() == 4,
"input must be 4-dimensional.", input->dim_size());
index_t out_height = 0;
index_t out_width = 0;
if (height_scale_ > 0) { // for Caffe
out_height = static_cast<index_t>(height_scale_ * input->dim(1));
out_width = static_cast<index_t>(width_scale_ * input->dim(2));
} else if (dim_.size() < 2) { // for variable tensor (Tf and ONNX)
const Tensor *size = this->Input(1);
Tensor::MappingGuard size_mapper(size);
MACE_CHECK(size->dim_size() == 1,
"size must be 1-dimensional.", size->dim_size());
out_height = size->data<int32_t>()[0];
out_width = size->data<int32_t>()[1];
} else { // for const tensor (Tf and ONNX)
out_height = dim_[0];
out_width = dim_[1];
}
return kernel_->Compute(context, input, size, dim_, output);
return kernel_->Compute(context, input, out_height, out_width, output);
}
private:
std::vector<index_t> dim_;
float height_scale_;
float width_scale_;
std::unique_ptr<OpenCLResizeNearestNeighborKernel> kernel_;
};
#endif // MACE_ENABLE_OPENCL
......
......@@ -95,7 +95,7 @@ class Tuner {
std::vector<param_type> opt_param = default_param;
RetType res = Tune<RetType>(param_generator, func, timer, &opt_param);
VLOG(3) << "Tuning " << param_key
<< " retult: " << MakeString(opt_param);
<< " result: " << MakeString(opt_param);
param_table_[obfucated_param_key] = opt_param;
return res;
} else {
......
......@@ -37,6 +37,7 @@ void ArgMaxTest(const std::vector<index_t> &input_shape,
.Input("Input")
.Input("axis")
.Output("Output")
.AddIntArg("keepdims", 0)
.OutputType({DT_INT32})
.Finalize(net.NewOperatorDef());
// Run
......
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/types.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class GroupNormOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestGroupNorm(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
int group_num,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<D, T>(MakeString("Input"), input_shape, input);
if (D == DeviceType::CPU) {
net.TransformDataFormat<CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
}
OpDefBuilder("GroupNorm", "GroupNormTest")
.Input(D == DeviceType::CPU ? "InputNCHW" : "Input")
.AddIntArg("group_num", group_num)
.Output(D == DeviceType::CPU ? "OutputNCHW" : "Output")
.Finalize(net.NewOperatorDef());
net.RunOp(D);
if (D == DeviceType::CPU) {
net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
}
net.AddInputFromArray<D, T>("ExpectedOutput", input_shape, output);
if (DataTypeToEnum<T>::value == DT_HALF) {
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"), 1e-2, 1e-2);
} else {
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"), 1e-3);
}
}
} // namespace
TEST_F(GroupNormOpTest, SimpleTestCPU) {
TestGroupNorm<DeviceType::CPU, float>(
{1, 1, 2, 64},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
1, 2, 3, 4, 5, 6, 7, 8},
16,
{-1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.80467, 0.928465, 1.05226,
1.17606, -1.52746, -1.09104, -0.654625, -0.218208, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.17606, -1.05226, -0.928465, -0.80467, 0.218208,
0.654625, 1.09104, 1.52746});
}
TEST_F(GroupNormOpTest, SimpleTestOpenCL) {
TestGroupNorm<DeviceType::GPU, float>(
{1, 1, 2, 64},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
1, 2, 3, 4, 5, 6, 7, 8},
16,
{-1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
-1.09104, -0.654625, -0.218208, 0.80467, 0.928465, 1.05226,
1.17606, -1.52746, -1.09104, -0.654625, -0.218208, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
1.52746, -1.17606, -1.05226, -0.928465, -0.80467, 0.218208,
0.654625, 1.09104, 1.52746});
}
TEST_F(GroupNormOpTest, SimpleRandomOPENCL) {
static unsigned int seed = time(NULL);
index_t batch = 1 + rand_r(&seed) % 5;
index_t group = 4 + 4 * (rand_r(&seed) % 3);
index_t group_num = 2 + rand_r(&seed) % 16;
index_t channels = group * group_num;
index_t height = 64;
index_t width = 64;
OpsTestNet net;
// Add input data
net.AddRandomInput<DeviceType::GPU, float>("Input",
{batch, height, width, channels});
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph
OpDefBuilder("GroupNorm", "GroupNormTest")
.Input("InputNCHW")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNCHW")
.AddIntArg("group_num", group_num)
.Finalize(net.NewOperatorDef());
// run cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check
auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output"));
// Run on opencl
OpDefBuilder("GroupNorm", "GroupNormTest")
.Input("Input")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.AddIntArg("group_num", group_num)
.Finalize(net.NewOperatorDef());
net.Setup(DeviceType::GPU);
// Tuning
setenv("MACE_TUNING", "1", 1);
net.Run();
unsetenv("MACE_TUNING");
// Run on opencl
net.Run();
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"),
1e-5, 1e-4);
}
TEST_F(GroupNormOpTest, SimpleRandomHalfOPENCL) {
// generate random input
static unsigned int seed = time(NULL);
index_t batch = 1 + rand_r(&seed) % 5;
index_t group = 4 + 4 * (rand_r(&seed) % 16);
index_t group_num = 2 + rand_r(&seed) % 16;
index_t channels = group * group_num;
index_t height = 64;
index_t width = 64;
OpsTestNet net;
// Add input data
net.AddRandomInput<DeviceType::GPU, float>("Input",
{batch, height, width, channels});
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph
OpDefBuilder("GroupNorm", "GroupNormTest")
.Input("InputNCHW")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNCHW")
.AddIntArg("group_num", group_num)
.Finalize(net.NewOperatorDef());
// run cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check
auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output"));
// Run on opencl
OpDefBuilder("GroupNorm", "GroupNormTest")
.Input("Input")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.AddIntArg("group_num", group_num)
.AddIntArg("T", static_cast<int>(DataType::DT_HALF))
.Finalize(net.NewOperatorDef());
net.Setup(DeviceType::GPU);
// Tuning
setenv("MACE_TUNING", "1", 1);
net.Run();
unsetenv("MACE_TUNING");
// Run on opencl
net.Run();
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"),
1e-1, 1e-2);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -1833,6 +1833,8 @@ message V1LayerParameter {
optional TransformationParameter transform_param = 36;
optional LossParameter loss_param = 42;
optional V0LayerParameter layer = 1;
optional ResizeNearestParameter resize_nearest_param = 204;
optional GroupNormParameter group_norm_param = 205;
}
// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters
......@@ -1946,3 +1948,14 @@ message ShuffleChannelParameter {
message L2NormalizationParameter {
optional int32 axis = 1 [default = 1];
}
message GroupNormParameter {
optional float eps = 1 [default = 1e-5];
optional int32 group_num = 2 [default = 32];
}
message ResizeNearestParameter {
optional float height_scale = 1 [default = 2.0];
optional float width_scale = 2 [default = 2.0];
}
......@@ -100,6 +100,7 @@ MaceSupportedOps = [
'Concat',
'Conv2D',
'Crop',
'Cumsum',
'Deconv2D',
'DepthToSpace',
'DepthwiseConv2d',
......@@ -111,15 +112,18 @@ MaceSupportedOps = [
'Fill',
'FullyConnected',
'Gather',
'GroupNorm',
'Identity',
'IfDefined',
'InferConv2dShape',
'KaldiBatchNorm',
'LocalResponseNorm',
'LpNorm',
'LSTMCell',
'LstmNonlinear',
'DynamicLSTM',
'MatMul',
'MVNorm',
'OneHot',
'Pad',
'PadContext',
......@@ -153,11 +157,8 @@ MaceSupportedOps = [
'Subsample',
'SumGroup',
'TargetRMSNorm',
'Transpose',
'Cumsum',
'Tile',
'LpNorm',
'MVNorm',
'Transpose',
]
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
......@@ -177,7 +178,8 @@ MaceFixedDataFormatOps = [MaceOp.BatchNorm,
MaceOp.SpaceToBatchND,
MaceOp.SpaceToDepth,
MaceOp.LpNorm,
MaceOp.MVNorm]
MaceOp.MVNorm,
MaceOp.GroupNorm]
MaceTransposableDataFormatOps = [MaceOp.Activation,
MaceOp.AddN,
......@@ -221,6 +223,8 @@ class MaceKeyword(object):
mace_batch_to_space_crops_str = 'crops'
mace_paddings_str = 'paddings'
mace_align_corners_str = 'align_corners'
mace_height_scale_str = 'height_scale'
mace_width_scale_str = 'width_scale'
mace_space_batch_block_shape_str = 'block_shape'
mace_space_depth_block_size_str = 'block_size'
mace_constant_value_str = 'constant_value'
......@@ -252,11 +256,14 @@ class MaceKeyword(object):
mace_opencl_mem_type = "opencl_mem_type"
mace_framework_type_str = "framework_type"
mace_group_str = "group"
mace_group_num_str = "group_num"
mace_wino_arg_str = "wino_block_size"
mace_quantize_flag_arg_str = "quantize_flag"
mace_epsilon_str = 'epsilon'
mace_reduce_type_str = 'reduce_type'
mace_argmin_str = 'argmin'
mace_out_val_str = 'out_val'
mace_top_k_str = 'top_k'
mace_round_mode_str = 'round_mode'
mace_min_size_str = 'min_size'
mace_max_size_str = 'max_size'
......
......@@ -161,6 +161,7 @@ class CaffeConverter(base_converter.ConverterInterface):
}
activation_type = {
'ReLU': ActivationType.RELU,
'ReLU6': ActivationType.RELUX,
'PReLU': ActivationType.PRELU,
'TanH': ActivationType.TANH,
'Sigmoid': ActivationType.SIGMOID,
......@@ -175,6 +176,7 @@ class CaffeConverter(base_converter.ConverterInterface):
'Eltwise': self.convert_elementwise,
'Add': self.convert_add,
'ReLU': self.convert_activation,
'ReLU6': self.convert_activation,
'TanH': self.convert_activation,
'Sigmoid': self.convert_activation,
'PReLU': self.convert_activation,
......@@ -186,6 +188,7 @@ class CaffeConverter(base_converter.ConverterInterface):
'InnerProduct': self.convert_fully_connected,
'Interp': self.convert_interp,
'BatchNorm': self.convert_folded_batchnorm,
'GroupNorm': self.convert_group_norm,
'Crop': self.convert_crop,
'Scale': self.convert_scale,
'ShuffleChannel': self.convert_channel_shuffle,
......@@ -196,7 +199,9 @@ class CaffeConverter(base_converter.ConverterInterface):
'L2Normalization': self.convert_lpnorm,
'L1Normalization': self.convert_lpnorm,
'MVN': self.convert_MVN,
'Bias': self.convert_Bias,
'Bias': self.convert_bias,
'ArgMax': self.convert_argmax,
'ResizeNearest': self.convert_resize_nearest,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -254,7 +259,7 @@ class CaffeConverter(base_converter.ConverterInterface):
for op in ops:
for i in six.moves.range(len(op.output)):
original_output_name = op.output[i].split('#')[0]
if original_output_name not in visited and\
if original_output_name not in visited and \
original_output_name not in self._option.input_nodes:
self.replace_input_name(
consumers.get(op.output[i], []),
......@@ -456,6 +461,7 @@ class CaffeConverter(base_converter.ConverterInterface):
filter_data = caffe_op.blobs[0]
self.add_tensor(filter_tensor_name, filter_data.shape,
mace_pb2.DT_FLOAT, filter_data)
print("convert conv2d, the filter shape is: ", filter_data.shape)
op.input.extend([filter_tensor_name])
if len(caffe_op.blobs) == 2:
......@@ -499,16 +505,18 @@ class CaffeConverter(base_converter.ConverterInterface):
self.add_tensor(alpha_tensor_name, alpha_data.reshape(-1).shape,
mace_pb2.DT_FLOAT, alpha_data)
op.input.extend([alpha_tensor_name])
negative_slope = caffe_op.layer.relu_param.negative_slope
if caffe_op.type == 'ReLU' and negative_slope != 0:
param_arg = op.arg.add()
param_arg.name = MaceKeyword.mace_activation_leakyrelu_coefficient_str # noqa
param_arg.f = caffe_op.layer.relu_param.negative_slope
type_arg.s = six.b(ActivationType.LEAKYRELU.name)
if caffe_op.type == 'Clip':
elif caffe_op.type == 'ReLU':
negative_slope = caffe_op.layer.relu_param.negative_slope
if negative_slope != 0:
param_arg = op.arg.add()
param_arg.name = MaceKeyword.mace_activation_leakyrelu_coefficient_str # noqa
param_arg.f = caffe_op.layer.relu_param.negative_slope
type_arg.s = six.b(ActivationType.LEAKYRELU.name)
elif caffe_op.type == 'ReLU6':
limit_arg = op.arg.add()
limit_arg.name = MaceKeyword.mace_activation_max_limit_str
limit_arg.f = 6.0
elif caffe_op.type == 'Clip':
mace_check(caffe_op.layer.clip_param.min == 0,
"Mace only supports min == 0 Clip op")
limit_arg = op.arg.add()
......@@ -548,6 +556,24 @@ class CaffeConverter(base_converter.ConverterInterface):
op.input.extend([name for name in input_names])
op.output[:] = scale_op.layer.top[:]
def convert_group_norm(self, caffe_op):
op = self.convert_general_op(caffe_op)
op.type = MaceOp.GroupNorm.name
epsilon_arg = op.arg.add()
epsilon_arg.name = MaceKeyword.mace_epsilon_str
group_num_arg = op.arg.add()
group_num_arg.name = MaceKeyword.mace_group_num_str
if hasattr(caffe_op, 'layer') and \
hasattr(caffe_op.layer, 'group_norm_param'):
param = caffe_op.layer.group_norm_param
epsilon_arg.f = param.eps
group_num_arg.i = param.group_num
else:
epsilon_arg.f = 1e-5
group_num_arg.i = 32
def convert_pooling(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.pooling_param
......@@ -668,11 +694,12 @@ class CaffeConverter(base_converter.ConverterInterface):
type_arg.name = MaceKeyword.mace_element_type_str
type_arg.i = EltwiseType.PROD.value
scale_tensor_name = scale_op_name + '_scale'
scale_data = caffe_op.blobs[0]
self.add_tensor(scale_tensor_name, scale_data.shape,
mace_pb2.DT_FLOAT, scale_data)
op.input.extend([scale_tensor_name])
if len(caffe_op.blobs) >= 1:
scale_tensor_name = scale_op_name + '_scale'
scale_data = caffe_op.blobs[0]
self.add_tensor(scale_tensor_name, scale_data.shape,
mace_pb2.DT_FLOAT, scale_data)
op.input.extend([scale_tensor_name])
if len(caffe_op.blobs) == 2:
bias_tensor_name = scale_op_name + '_offset'
......@@ -802,8 +829,9 @@ class CaffeConverter(base_converter.ConverterInterface):
mace_check(step_w_arg.f > 0, "step_w should be larger than 0.")
if param.HasField('step'):
mace_check(not param.HasField('step_h') and not param.HasField('step_w'), # noqa
"Either step or step_h/step_w should be specified; not both.") # noqa
mace_check(
not param.HasField('step_h') and not param.HasField('step_w'),
"Either step or step_h/step_w should be specified; not both.")
mace_check(param.step > 0, "step should be larger than 0.")
step_h_arg.f = param.step
step_w_arg.f = param.step
......@@ -869,7 +897,7 @@ class CaffeConverter(base_converter.ConverterInterface):
eps_arg.name = MaceKeyword.mace_epsilon_str
eps_arg.f = param.eps
def convert_Bias(self, caffe_op):
def convert_bias(self, caffe_op):
op = self.convert_general_op(caffe_op)
op.type = MaceOp.BiasAdd.name
param = caffe_op.layer.bias_param
......@@ -882,3 +910,58 @@ class CaffeConverter(base_converter.ConverterInterface):
mace_check(param.axis == 0 or param.axis == 1,
"BiasAdd only support axis with 0 or 1.")
axis_arg.i = param.axis
if len(caffe_op.blobs) >= 1:
bias_tensor_name = op.name + '_bias'
bias_data = caffe_op.blobs[0]
self.add_tensor(bias_tensor_name, bias_data.shape,
mace_pb2.DT_FLOAT, bias_data)
op.input.extend([bias_tensor_name])
def convert_resize_nearest(self, caffe_op):
op = self.convert_general_op(caffe_op)
op.type = MaceOp.ResizeNearestNeighbor.name
align_corners_arg = op.arg.add()
align_corners_arg.name = MaceKeyword.mace_align_corners_str
align_corners_arg.i = 0
height_scale_arg = op.arg.add()
height_scale_arg.name = MaceKeyword.mace_height_scale_str
width_scale_arg = op.arg.add()
width_scale_arg.name = MaceKeyword.mace_width_scale_str
if hasattr(caffe_op, 'layer') and \
hasattr(caffe_op.layer, 'resize_nearest_param'):
param = caffe_op.layer.resize_nearest_param
height_scale_arg.f = param.height_scale
width_scale_arg.f = param.width_scale
else:
height_scale_arg.f = 2.0
width_scale_arg.f = 2.0
def convert_argmax(self, caffe_op):
op = self.convert_general_op(caffe_op)
op.type = MaceOp.ArgMax.name
out_max_val = False
if hasattr(caffe_op, 'layer') and \
hasattr(caffe_op.layer, 'argmax_param'):
param = caffe_op.layer.argmax_param
if hasattr(param, 'out_max_val'):
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_out_val_str
axis_arg.i = param.out_max_val
out_max_val = param.out_max_val
if hasattr(param, MaceKeyword.mace_top_k_str):
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_top_k_str
axis_arg.i = param.top_k
if hasattr(param, MaceKeyword.mace_axis_str):
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = param.axis
if out_max_val:
op.output_type.extend([mace_pb2.DT_FLOAT])
else:
op.output_type.extend([mace_pb2.DT_INT32])
......@@ -36,8 +36,9 @@ class ShapeInference(object):
MaceOp.Deconv2D.name: self.infer_shape_deconv,
MaceOp.DepthwiseConv2d.name: self.infer_shape_conv_pool_shape,
MaceOp.DepthwiseDeconv2d.name: self.infer_shape_deconv,
MaceOp.Eltwise.name: self.infer_shape_general,
MaceOp.Eltwise.name: self.infer_shape_eltwise,
MaceOp.BatchNorm.name: self.infer_shape_general,
MaceOp.GroupNorm.name: self.infer_shape_general,
MaceOp.AddN.name: self.infer_shape_general,
MaceOp.Activation.name: self.infer_shape_general,
MaceOp.Pooling.name: self.infer_shape_conv_pool_shape,
......@@ -54,6 +55,9 @@ class ShapeInference(object):
MaceOp.ResizeBilinear.name: self.infer_shape_resize_bilinear,
MaceOp.LpNorm.name: self.infer_shape_general,
MaceOp.MVNorm.name: self.infer_shape_general,
MaceOp.ResizeNearestNeighbor.name:
self.infer_shape_nearest_neighbor,
MaceOp.ArgMax.name: self.infer_shape_argmax,
}
self._net = net
......@@ -131,7 +135,7 @@ class ShapeInference(object):
output_shape[0] = input_shape[0]
if ConverterUtil.data_format(op) == DataFormat.NCHW \
and ConverterUtil.filter_format(self._net) == DataFormat.OIHW: # noqa
and ConverterUtil.filter_format(self._net) == DataFormat.OIHW:
# filter format: OIHW
if op.type == MaceOp.DepthwiseConv2d.name:
output_shape[1] = filter_shape[0] * filter_shape[1]
......@@ -172,7 +176,7 @@ class ShapeInference(object):
MaceKeyword.mace_group_str)
output_shape[0] = input_shape[0]
if ConverterUtil.data_format(op) == DataFormat.NCHW \
and ConverterUtil.filter_format(self._net) == DataFormat.OIHW: # noqa
and ConverterUtil.filter_format(self._net) == DataFormat.OIHW:
# filter format: IOHW
output_shape[1] = filter_shape[1]
if group_arg is not None and group_arg.i > 1:
......@@ -250,9 +254,12 @@ class ShapeInference(object):
input_shape = list(self._output_shape_cache[op.input[0]])
input_w = input_shape[3]
input_h = input_shape[2]
min_size = ConverterUtil.get_arg(op, MaceKeyword.mace_min_size_str).floats # noqa
max_size = ConverterUtil.get_arg(op, MaceKeyword.mace_max_size_str).floats # noqa
aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa
min_size = \
ConverterUtil.get_arg(op, MaceKeyword.mace_min_size_str).floats
max_size = \
ConverterUtil.get_arg(op, MaceKeyword.mace_max_size_str).floats
aspect_ratio = \
ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats
num_prior = len(aspect_ratio) * len(min_size) + len(max_size)
output_shape[2] = int(num_prior * input_h * input_w * 4)
......@@ -282,7 +289,8 @@ class ShapeInference(object):
else:
output_shape = []
axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
end_axis = ConverterUtil.get_arg(op, MaceKeyword.mace_end_axis_str).i # noqa
end_axis = ConverterUtil.get_arg(op,
MaceKeyword.mace_end_axis_str).i
end_axis = end_axis if end_axis > 0 else end_axis + len(
list(self._output_shape_cache[op.input[0]]))
dim = 1
......@@ -310,3 +318,73 @@ class ShapeInference(object):
mace_check(False, "format %s is not supported"
% ConverterUtil.data_format(op))
self.add_output_shape(op, [output_shape])
def infer_shape_nearest_neighbor(self, op):
input_shape = self._output_shape_cache[op.input[0]]
height_scale = \
ConverterUtil.get_arg(op, MaceKeyword.mace_height_scale_str).f
width_scale = \
ConverterUtil.get_arg(op, MaceKeyword.mace_width_scale_str).f
if ConverterUtil.data_format(op) == DataFormat.NCHW:
output_shape = [input_shape[0], input_shape[1],
int(input_shape[2] * height_scale),
int(input_shape[3] * width_scale)]
elif ConverterUtil.data_format(op) == DataFormat.NHWC:
output_shape = [input_shape[0], int(input_shape[2] * height_scale),
int(input_shape[3] * width_scale), input_shape[3]]
else:
output_shape = []
mace_check(False, "format %s is not supported"
% ConverterUtil.data_format(op))
self.add_output_shape(op, [output_shape])
def infer_shape_argmax(self, op):
input_shape = self._output_shape_cache[op.input[0]]
output_dim_num = len(input_shape)
if output_dim_num < 3:
output_dim_num = 3
axis_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
has_axis = (axis_arg is not None)
axis_value = 0
if has_axis:
axis_value = axis_arg.i
if axis_value < 0:
axis_value = len(input_shape) + axis_value
top_k = ConverterUtil.get_arg(op, MaceKeyword.mace_top_k_str).i
mace_check(top_k >= 1, "Invalid top_k value")
out_val = ConverterUtil.get_arg(op, MaceKeyword.mace_out_val_str).i
if has_axis: # Produces max_ind or max_val per axis
output_shape = input_shape
output_shape[axis_value] = top_k
else:
output_shape = [1] * output_dim_num
output_shape[0] = input_shape[0]
output_shape[2] = top_k
if out_val: # Produces max_ind and max_val
output_shape[1] = 2
self.add_output_shape(op, [output_shape])
def infer_shape_eltwise(self, op):
input_num = len(op.input)
mace_check(input_num > 0, "input num should > 0")
max_idx = 0
max_input_size = 0
for i in range(0, input_num):
mace_check(op.input[i] in self._output_shape_cache,
"Op %s input %s does not exist"
% (op.name, op.input[i]))
input_shape = self._output_shape_cache[op.input[i]]
input_size = 1
for k in range(0, len(input_shape)):
input_size *= input_shape[k]
if input_size > max_input_size:
max_idx = i
max_input_size = input_size
input_max_shape = self._output_shape_cache[op.input[max_idx]]
self.add_output_shape(op, [input_max_shape])
......@@ -1046,6 +1046,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
op.type = MaceOp.ArgMax.name
op.output_type.extend([mace_pb2.DT_INT32])
keep_dims_arg = op.arg.add()
keep_dims_arg.name = MaceKeyword.mace_keepdims_str
keep_dims_arg.i = 0
def convert_split(self, tf_op):
op = self.convert_general_op(tf_op)
num_or_size_splits = tf_op.get_attr('num_split')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册