提交 a610d506 编写于 作者: 叶剑武

Merge branch 'lp_normalization' into 'master'

add lpnorm、mvnorm op for caffe, enhance biasadd、reshape op

See merge request !1224
......@@ -90,7 +90,9 @@ MemoryBlock MemoryOptimizer::CreateMemoryBlock(
if (shape.size() == 2) {
shape = {shape[0], 1, 1, shape[1]};
} else {
MACE_CHECK(shape.size() == 4) << "GPU only support 2D/4D input";
MACE_CHECK(shape.size() == 4) << "GPU only support 2D/4D input, "
<< "op name: " << op_def->name() << ", "
<< MakeString(shape);
}
OpenCLUtil::CalImage2DShape(shape, buffer_type, &image_shape);
block.set_x(image_shape[0]);
......
......@@ -62,34 +62,62 @@ void BiasAdd::AddBias(const OpContext *context,
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
for (index_t c = start1; c < end1; c += step1) {
const index_t offset = (b * channels + c) * image_size;
auto input_ptr = input_data + offset;
auto output_ptr = output_data + offset;
const float bias = bias_data[c];
float32x4_t vbias = vdupq_n_f32(bias);
if (bias->dim_size() == 1) {
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
const index_t b_offset = b * channels;
for (index_t c = start1; c < end1; c += step1) {
const index_t offset = (b_offset + c) * image_size;
auto input_ptr = input_data + offset;
auto output_ptr = output_data + offset;
const float bias = bias_data[c];
float32x4_t vbias = vdupq_n_f32(bias);
for (index_t i = 0; i < block_count; ++i) {
float32x4_t v = vld1q_f32(input_ptr);
v = vaddq_f32(v, vbias);
vst1q_f32(output_ptr, v);
for (index_t i = 0; i < block_count; ++i) {
float32x4_t v = vld1q_f32(input_ptr);
v = vaddq_f32(v, vbias);
vst1q_f32(output_ptr, v);
input_ptr += 4;
output_ptr += 4;
input_ptr += 4;
output_ptr += 4;
}
for (index_t i = 0; i < remain; ++i) {
(*output_ptr++) = (*input_ptr++) + bias;
}
}
for (index_t i = 0; i < remain; ++i) {
(*output_ptr++) = (*input_ptr++) + bias;
}
}, 0, batch, 1, 0, channels, 1);
} else {
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
const index_t b_offset = b * channels;
for (index_t c = start1; c < end1; c += step1) {
const index_t offset = (b_offset + c) * image_size;
auto input_ptr = input_data + offset;
auto output_ptr = output_data + offset;
const float bias = bias_data[b * channels + c];
float32x4_t vbias = vdupq_n_f32(bias);
for (index_t i = 0; i < block_count; ++i) {
float32x4_t v = vld1q_f32(input_ptr);
v = vaddq_f32(v, vbias);
vst1q_f32(output_ptr, v);
input_ptr += 4;
output_ptr += 4;
}
for (index_t i = 0; i < remain; ++i) {
(*output_ptr++) = (*input_ptr++) + bias;
}
}
}
}
}, 0, batch, 1, 0, channels, 1);
}, 0, batch, 1, 0, channels, 1);
}
}
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
......@@ -49,15 +49,18 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
const Tensor *bias = this->Input(1);
MACE_CHECK(bias->dim_size() == 1, "bias must be 1-dimensional. ",
bias->dim_size());
Tensor *output = this->Output(0);
if (input->dim_size() == 4 && has_data_format_) {
if (input->dim_size() == 4 && (has_data_format_
|| input->data_format() == DataFormat::NCHW)) { // NCHW
MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
"bias must be 1-dimensional or n*c for caffee.",
MakeString(bias->shape()));
bias_add_delegator_.Compute(context, input, bias, output);
} else {
} else { // NHWC
MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
"bias must be 1 or 2 dimensionals for caffee.",
bias->dim_size(), MakeString(bias->shape()));
// TODO(liyin): remove it and tranform bias to add (eltwise)
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
......@@ -70,16 +73,40 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
float *output_ptr = output->mutable_data<float>();
const std::vector<index_t> &shape = input->shape();
const index_t fused_batch = std::accumulate(
shape.begin(), shape.end() - 1, 1, std::multiplies<index_t>());
const index_t channels = *shape.rbegin();
for (index_t n = 0; n < fused_batch; ++n) {
index_t pos = n * channels;
for (index_t c = 0; c < channels; ++c) {
output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
++pos;
}
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
if (bias->dim_size() == 1) {
const index_t fused_batch = std::accumulate(
shape.begin(), shape.end() - 1, 1, std::multiplies<index_t>());
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t n = start; n < end; n += step) {
index_t pos = n * channels;
for (index_t c = 0; c < channels; ++c) {
output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
++pos;
}
}
}, 0, fused_batch, 1);
} else { // bias is 2d
const auto n = shape[0];
MACE_CHECK(n == bias->shape()[0]);
const index_t fused_hw = std::accumulate(
shape.begin() + 1, shape.end() - 1, 1, std::multiplies<index_t>());
const auto ch_size = bias->shape()[1];
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
auto offset = i * fused_hw;
auto bias_offset = i * ch_size;
for (index_t j = start1; j < end1; j += step1) {
index_t pos = (offset + i) * channels;
for (index_t c = 0; c < channels; ++c, ++pos) {
output_ptr[pos] = input_ptr[pos] + bias_ptr[bias_offset + c];
}
}
}
}, 0, n, 1, 0, fused_hw, 1);
}
}
......@@ -109,21 +136,25 @@ class BiasAddOp<DeviceType::GPU, float> : public Operation {
} else {
MACE_NOT_IMPLEMENTED;
}
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1, OpenCLBufferType::ARGUMENT, mem_type)
== MaceStatus::MACE_SUCCESS);
// for const bias tensor
if (context->workspace()->GetTensor(operator_def_->input(1)) != nullptr) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1, OpenCLBufferType::ARGUMENT, mem_type)
== MaceStatus::MACE_SUCCESS, "TransformFilter failed");
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(0);
const Tensor *bias = this->Input(1);
MACE_CHECK(bias->dim_size() == 1, "bias must be 1-dimensional. ",
bias->dim_size());
Tensor *output = this->Output(0);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
MACE_CHECK(input->dim_size() == 4 && has_data_format_,
"gpu only support biasadd for 4-dimensional NHWC format tensor");
MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
"bias must be 1-dimensional or 2-dimensional for caffee. ",
MakeString(bias->shape()));
return kernel_->Compute(context, input, bias, output);
}
......@@ -151,6 +182,10 @@ void RegisterBiasAdd(OpRegistryBase *op_registry) {
*op, "has_data_format", 0);
if (!has_data_format ||
op->output_shape(0).dims_size() != 4) {
LOG(INFO) << "BiasAdd only support cpu, has_data_format="
<< has_data_format
<< ", op->output_shape(0).dims_size()="
<< op->output_shape(0).dims_size();
return {DeviceType::CPU};
}
return {DeviceType::CPU, DeviceType::GPU};
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <memory>
#include "mace/core/operator.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/lpnorm.h"
#endif // MACE_ENABLE_OPENCL
/**
* LpNormOp is a Normalization OP which support L1 and L2, which is a custom op
* of caffe (not exist in official caffe), please reference:
* https://github.com/freesouls/caffe/blob/master/src/caffe/layers/normalization_layer.cpp #noqa
*/
namespace mace {
namespace ops {
template<DeviceType D, typename T>
class LpNormOp;
template<>
class LpNormOp<DeviceType::CPU, float> : public Operation {
public:
explicit LpNormOp(OpConstructContext *context)
: Operation(context),
p_(Operation::GetOptionalArg<int>("p", 2)),
axis_(Operation::GetOptionalArg<int>("axis", -1)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
if (axis_ < 0) {
axis_ += input->dim_size();
}
MACE_CHECK(axis_ < input->dim_size() && axis_ >= 0,
"The axis_ must be small than dim size");
const std::vector<index_t> &input_shape = input->shape();
MACE_RETURN_IF_ERROR(output->Resize(input_shape));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const auto *input_data = input->data<float>();
auto *output_data = output->mutable_data<float>();
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
auto outer_loop = std::accumulate(input_shape.begin(),
input_shape.begin() + axis_, 1,
std::multiplies<index_t>());
auto inner_loop = std::accumulate(input_shape.begin() + axis_,
input_shape.end(), 1,
std::multiplies<index_t>());
if (p_ == 1) {
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = std::abs(input_data[i]);
}
}, 0, input->size(), 1);
} else if (p_ == 2) {
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = input_data[i] * input_data[i];
}
}, 0, input->size(), 1);
} else {
LOG(FATAL) << "LpNorm's p should be 1 or 2, current p is: " << p_;
}
const float power = 1 / static_cast<float>(p_);
auto norm_buffer = context->device()->scratch_buffer();
norm_buffer->Rewind();
MACE_RETURN_IF_ERROR(norm_buffer->GrowSize(outer_loop * sizeof(float)));
float *norm_ptr = norm_buffer->mutable_data<float>();
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
auto output_data_base = output_data + inner_loop * i;
norm_ptr[i] = std::accumulate(output_data_base,
output_data_base + inner_loop, 0.0f);
norm_ptr[i] = std::pow(norm_ptr[i], power);
norm_ptr[i] += 1e-6;
}
}, 0, outer_loop, 1);
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
const auto offset = i * inner_loop;
for (index_t j = start1; j < end1; j += step1) {
output_data[offset + j] = input_data[offset + j] / norm_ptr[i];
}
}
}, 0, outer_loop, 1, 0, inner_loop, 1);
return MaceStatus::MACE_SUCCESS;
}
private:
int p_;
int axis_;
};
#ifdef MACE_ENABLE_OPENCL
template<>
class LpNormOp<DeviceType::GPU, float> : public Operation {
public:
explicit LpNormOp(OpConstructContext *context)
: Operation(context) {
const auto p = Operation::GetOptionalArg<int>("p", 2);
const auto axis = Operation::GetOptionalArg<int>("axis", -1);
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::LpNormKernel>(p, axis);
} else {
MACE_NOT_IMPLEMENTED;
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
return kernel_->Compute(context, input, output);
}
private:
std::unique_ptr<OpenCLLpNormKernel> kernel_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#endif // MACE_ENABLE_OPENCL
void RegisterLpNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "LpNorm", LpNormOp,
DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "LpNorm", LpNormOp);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <memory>
#include "mace/core/operator.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/mvnorm.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace ops {
// Mean-Variance Normalization (MVN)
template<DeviceType D, typename T>
class MVNormOp;
template<>
class MVNormOp<DeviceType::CPU, float> : public Operation {
public:
explicit MVNormOp(OpConstructContext *context)
: Operation(context),
normalize_variance_(
Operation::GetOptionalArg<bool>("normalize_variance", true)),
across_channels_(
Operation::GetOptionalArg<bool>("across_channels", false)),
eps_(Operation::GetOptionalArg<float>("epsilon", 1e-9)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
MACE_CHECK(input->data_format() == DataFormat::NCHW,
"The MVN only suport NCHW");
const std::vector<index_t> &input_shape = input->shape();
MACE_RETURN_IF_ERROR(output->Resize(input_shape));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const auto *input_data = input->data<float>();
auto *output_data = output->mutable_data<float>();
const auto input_size = input->size();
const auto outer_loop =
across_channels_ ? input_shape[0] : input_shape[0] * input_shape[1];
const auto inner_loop = input_size / outer_loop;
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
Buffer mean_buffer(context->device()->allocator());
MACE_RETURN_IF_ERROR(mean_buffer.Allocate(outer_loop * sizeof(float)));
auto *mean_ptr = mean_buffer.mutable_data<float>();
// compute EX
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
const auto offset = inner_loop * i;
mean_ptr[i] = std::accumulate(input_data + offset,
input_data + offset + inner_loop, 0.0f);
mean_ptr[i] /= inner_loop;
}
}, 0, outer_loop, 1);
// compute (X - EX)
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
const auto offset = i * inner_loop;
for (index_t j = start1; j < end1; j += step1) {
output_data[offset + j] = input_data[offset + j] - mean_ptr[i];
}
}
}, 0, outer_loop, 1, 0, inner_loop, 1);
if (normalize_variance_) {
// compute (X - EX)^2
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = output_data[i] * output_data[i];
}
}, 0, input_size, 1);
auto mean_v_buffer = context->device()->scratch_buffer();
mean_v_buffer->Rewind();
MACE_RETURN_IF_ERROR(
mean_v_buffer->GrowSize(outer_loop * sizeof(float)));
float *mean_v_ptr = mean_v_buffer->mutable_data<float>();
// compute E((X - EX)^2)^0.5 + eps_
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
auto output_data_base = output_data + inner_loop * i;
mean_v_ptr[i] = std::accumulate(output_data_base,
output_data_base + inner_loop, 0.0f);
mean_v_ptr[i] = std::pow(mean_v_ptr[i] / inner_loop, 0.5f) + eps_;
}
}, 0, outer_loop, 1);
// compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
const auto offset = i * inner_loop;
for (index_t j = start1; j < end1; j += step1) {
output_data[offset + j] =
(input_data[offset + j] - mean_ptr[i]) / mean_v_ptr[i];
}
}
}, 0, outer_loop, 1, 0, inner_loop, 1);
}
return MaceStatus::MACE_SUCCESS;
}
private:
bool normalize_variance_;
bool across_channels_;
float eps_;
};
#ifdef MACE_ENABLE_OPENCL
template<>
class MVNormOp<DeviceType::GPU, float> : public Operation {
public:
explicit MVNormOp(OpConstructContext *context) : Operation(context) {
auto normalize_variance =
Operation::GetOptionalArg<bool>("normalize_variance", true);
auto across_channels =
Operation::GetOptionalArg<bool>("across_channels", false);
auto eps = Operation::GetOptionalArg<float>("epsilon", 1e-9);
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::MVNormKernel>(
normalize_variance, across_channels, eps);
} else {
MACE_NOT_IMPLEMENTED;
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
return kernel_->Compute(context, input, output);
}
private:
std::unique_ptr<OpenCLMVNormKernel> kernel_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#endif // MACE_ENABLE_OPENCL
void RegisterMVNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "MVNorm", MVNormOp,
DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "MVNorm", MVNormOp);
}
} // namespace ops
} // namespace mace
......@@ -2,25 +2,27 @@
// Supported data types: half/float
__kernel void bias_add(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__private const int input_height,
__read_only image2d_t input,
__read_only image2d_t bias,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (ch_blk >= global_size_dim0 || w >= global_size_dim1
|| hb >= global_size_dim2) {
if (ch_blk >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int width = global_size_dim1;
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 bias_value = READ_IMAGET(bias, SAMPLER, (int2)(ch_blk, 0));
const int pos = mad24(ch_blk, width, width_idx);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
const int b_idx = select(0, hb_idx / input_height, input_height > 0);
DATA_TYPE4 bias_value = READ_IMAGET(bias, SAMPLER, (int2)(ch_blk, b_idx));
DATA_TYPE4 out = in + bias_value;
WRITE_IMAGET(output, (int2)(pos, hb), out);
WRITE_IMAGET(output, (int2)(pos, hb_idx), out);
}
#include <common.h>
DATA_TYPE4 compute_total(__read_only image2d_t input, const int hb_base,
const int chan_blks, const int width, const int height,
const int hb_idx, const int chan_blk_idx) {
DATA_TYPE4 total = 0.0f;
#if PARAM_AXIS == 1
const int wc_blks = mul24(width, chan_blks);
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int pos = 0; pos < wc_blks; ++pos) {
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
#if PARAM_P == 1
total += fabs(in_data);
#else
total = mad(in_data, in_data, total);
#endif
}
}
DATA_TYPE total_all = total.x + total.y + total.z + total.w;
total = (DATA_TYPE4){total_all, total_all, total_all, total_all};
#elif PARAM_AXIS == 2
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int w_idx = 0; w_idx < width; ++w_idx) {
int pos = mad24(chan_blk_idx, width, w_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
#if PARAM_P == 1
total = total + fabs(in_data);
#else
total = mad(in_data, in_data, total);
#endif
}
}
#elif PARAM_AXIS == 3
for (int w_idx = 0; w_idx < width; ++x) {
int pos = mad24(chan_blk_idx, width, w_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
#if PARAM_P == 1
total = total + fabs(in_data);
#else
total = mad(in_data, in_data, total);
#endif
}
#endif
return total;
}
__kernel void lpnorm(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int height,
__private const float eps,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int chan_blks = global_size_dim0;
const int width = global_size_dim1;
const int hb = global_size_dim2;
const int hb_base = mul24(hb_idx / height, height);
DATA_TYPE4 total = compute_total(input, hb_base, chan_blks, width, height,
hb_idx, chan_blk_idx);
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
#if PARAM_P == 1
in_data = in_data / (total + eps);
#else
in_data = in_data / (sqrt(total) + eps);
#endif
WRITE_IMAGET(output, (int2)(pos, hb_idx), in_data);
}
#include <common.h>
DATA_TYPE4 compute_mean_image(image2d_t input, const int width_idx,
const int hb_idx, const int chan_blks,
const int height, const int width) {
DATA_TYPE4 total = 0.0f;
DATA_TYPE4 mean = 0.0f;
const int hb_base = mul24(hb_idx / height, height);
const int wc_blks = mul24(width, chan_blks);
#ifdef ACROSS_CHANNELS
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int pos = 0; pos < wc_blks; ++pos) {
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
total += in_data;
}
}
DATA_TYPE total_value = total.x + total.y + total.z + total.w;
DATA_TYPE mean_value = total_value / (DATA_TYPE)(mul24(mul24(height, wc_blks), 4));
mean = (DATA_TYPE4){mean_value, mean_value, mean_value, mean_value};
#else
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int w_idx = 0; w_idx < width; ++w_idx) {
int pos = mad24(w_idx, chan_blks, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
total += in_data;
}
}
mean = total / mul24(height, width);
#endif
return mean;
}
__kernel void mvnorm_mean(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int height,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int chan_blks = global_size_dim0;
const int width = global_size_dim1;
DATA_TYPE4 mean = compute_mean_image(input, width_idx,
hb_idx, chan_blks, height, width);
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
in_data -= mean;
WRITE_IMAGET(output, (int2)(pos, hb_idx), in_data);
}
__kernel void mvnorm_vn_step1(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__write_only image2d_t mean_image, // E(X)
__write_only image2d_t square_image, // (X - EX)^2
__private const int height) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int chan_blks = global_size_dim0;
const int width = global_size_dim1;
DATA_TYPE4 mean = compute_mean_image(input, width_idx,
hb_idx, chan_blks, height, width);
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
in_data = in_data - mean;
DATA_TYPE4 pow_data = in_data * in_data;
if (hb_idx == 0 && width_idx == 0) {
WRITE_IMAGET(mean_image, (int2)(chan_blk_idx, 0), mean);
}
WRITE_IMAGET(square_image, (int2)(pos, hb_idx), pow_data);
}
__kernel void mvnorm_vn_step2(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__read_only image2d_t mean_image, // E(X)
__read_only image2d_t square_image, // (X - EX)^2
__private const int height,
__private const float eps,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int chan_blks = global_size_dim0;
const int width = global_size_dim1;
DATA_TYPE4 mean = READ_IMAGET(mean_image, SAMPLER, (int2)(chan_blk_idx, 0));
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
in_data = in_data - mean;
DATA_TYPE4 mean_v = compute_mean_image(square_image, width_idx,
hb_idx, chan_blks, height, width);
DATA_TYPE4 norm_data = in_data / (sqrt(mean_v) + eps);
WRITE_IMAGET(output, (int2)(pos, hb_idx), norm_data);
}
......@@ -56,6 +56,7 @@ MaceStatus BiasAddKernel::Compute(
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
MACE_SET_3D_GWS_ARGS(kernel_, gws);
kernel_.setArg(idx++, static_cast<int>(bias->dim_size() > 1 ? height : 0));
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, *(bias->opencl_image()));
kernel_.setArg(idx++, *(output->opencl_image()));
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/opencl/image/lpnorm.h"
#include <set>
#include <string>
#include <vector>
namespace mace {
namespace ops {
namespace opencl {
namespace image {
LpNormKernel::LpNormKernel(const int p, const int axis) : p_(p), axis_(axis) {
MACE_CHECK(p_ == 1 || p_ == 2, "Current p is: ", p);
}
MaceStatus LpNormKernel::Compute(OpContext *context,
const Tensor *input, Tensor *output) {
if (axis_ < 0) {
axis_ += input->dim_size();
}
MACE_CHECK(axis_ == 1 || axis_ == 2 || axis_ == 3,
"Current axis is: ", axis_);
const auto batch = input->dim(0);
const auto height = input->dim(1);
const auto width = input->dim(2);
const auto channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("lpnorm");
built_options.emplace("-Dlpnorm=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DT_FLOAT));
std::stringstream param_p;
param_p << "-DPARAM_P=" << p_;
built_options.emplace(param_p.str());
std::stringstream param_axis;
param_axis << "-DPARAM_AXIS=" << axis_;
built_options.emplace(param_axis.str());
MACE_RETURN_IF_ERROR(runtime->BuildKernel("lpnorm", kernel_name,
built_options, &kernel_));
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
MACE_OUT_OF_RANGE_INIT(kernel_);
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
MACE_SET_3D_GWS_ARGS(kernel_, gws);
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, static_cast<int>(height));
kernel_.setArg(idx++, static_cast<float>(1e-6));
kernel_.setArg(idx++, *(output->opencl_image()));
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_);
std::string tuning_key =
Concat("lpnorm_opencl_kernel", batch, height, width, channels, p_, axis_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_IMAGE_LPNORM_H_
#define MACE_OPS_OPENCL_IMAGE_LPNORM_H_
#include "mace/core/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/tensor.h"
#include "mace/ops/opencl/helper.h"
#include "mace/ops/opencl/lpnorm.h"
namespace mace {
namespace ops {
namespace opencl {
namespace image {
class LpNormKernel : public OpenCLLpNormKernel {
public:
explicit LpNormKernel(const int p, const int axis);
~LpNormKernel() = default;
MaceStatus Compute(
OpContext *context, const Tensor *input, Tensor *output) override;
private:
int p_;
int axis_;
cl::Kernel kernel_;
uint32_t kwg_size_;
};
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_IMAGE_LPNORM_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/opencl/image/mvnorm.h"
#include <memory>
#include <set>
#include <string>
#include <vector>
namespace mace {
namespace ops {
namespace opencl {
namespace image {
namespace {
MaceStatus BuildMVNKernel(OpenCLRuntime *runtime, cl::Kernel *kernel,
const char *kernel_name,
std::set<std::string> *built_options,
bool across_channel) {
std::stringstream micro_name;
micro_name << "-Dmvnorm=" << kernel_name;
built_options->emplace(micro_name.str());
built_options->emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
built_options->emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DT_FLOAT));
if (across_channel) {
built_options->emplace("-DACROSS_CHANNELS");
}
MACE_RETURN_IF_ERROR(runtime->BuildKernel("mvnorm", kernel_name,
*built_options, kernel));
return MaceStatus::MACE_SUCCESS;
}
std::unique_ptr<Image> CreateImage(
OpContext *context, const DataType dt,
const std::vector<index_t> &buffer_shape) {
std::unique_ptr<Image> image =
make_unique<Image>(context->device()->allocator());
std::vector<size_t> shape;
OpenCLUtil::CalImage2DShape(
buffer_shape, OpenCLBufferType::IN_OUT_CHANNEL, &shape);
MACE_CHECK(image->Allocate(shape, dt) == MaceStatus::MACE_SUCCESS);
VLOG(1) << "MVNormKernel::CreateImage allocate image_:" << MakeString(shape);
return image;
}
} // namespace
MVNormKernel::MVNormKernel(bool normalize_variance,
bool across_channels, float eps)
: normalize_variance_(normalize_variance),
across_channels_(across_channels),
eps_(eps) {}
void MVNormKernel::CheckImage(OpContext *context, const DataType dt,
const std::vector<index_t> &square_shape,
const std::vector<index_t> &mean_shape) {
if (square_image_ == nullptr) {
square_image_ = CreateImage(context, dt, square_shape);
}
if (mean_image_ == nullptr) {
mean_image_ = CreateImage(context, dt, mean_shape);
}
}
MaceStatus MVNormKernel::Compute(OpContext
*context,
const Tensor *input, Tensor
*output) {
const auto batch = input->dim(0);
const auto height = input->dim(1);
const auto width = input->dim(2);
const auto channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
if (normalize_variance_) {
const std::vector<index_t> &square_shape = input->buffer_shape();
const std::vector<index_t> mean_shape = {1, 1, 1, channels};
CheckImage(context, input->dtype(), square_shape, mean_shape);
// compute the (X - EX)^2
MACE_RETURN_IF_ERROR(ExecuteVarianceNormStep1Kernel(
context, runtime, gws, input));
// compute the compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
MACE_RETURN_IF_ERROR(ExecuteVarianceNormStep2Kernel(
context, runtime, gws, input, output));
} else {
MACE_RETURN_IF_ERROR(ExecuteMeanNormKernel(
context, runtime, gws, input, output));
}
return
MaceStatus::MACE_SUCCESS;
}
MaceStatus MVNormKernel::ExecuteMeanNormKernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input,
Tensor *output) {
const auto height = input->dim(1);
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_step1_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step1_, "mvnorm_mean",
&built_options, across_channels_));
kwg_size_step1_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_step1_));
}
MACE_OUT_OF_RANGE_INIT(kernel_step1_);
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_step1_);
MACE_SET_3D_GWS_ARGS(kernel_step1_, gws);
kernel_step1_.setArg(idx++, *(input->opencl_image()));
kernel_step1_.setArg(idx++, static_cast<int>(height));
kernel_step1_.setArg(idx++, *(output->opencl_image()));
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step1_);
std::string
tuning_key = Concat("mvnorm_mean_opencl_kernel", gws[0], gws[1], gws[2],
normalize_variance_, across_channels_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step1_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
// The first step of compute Variance Norm, compute the (X - EX)^2
// store them into the square_image_
MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
OpContext *context, OpenCLRuntime *runtime,
const uint32_t (&gws)[3], const Tensor *input) {
const auto height = input->dim(1);
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_step1_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step1_,
"mvnorm_vn_step1",
&built_options, across_channels_));
kwg_size_step1_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_step1_));
}
MACE_OUT_OF_RANGE_INIT(kernel_step1_);
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_step1_);
MACE_SET_3D_GWS_ARGS(kernel_step1_, gws);
kernel_step1_.setArg(idx++, *(input->opencl_image()));
cl::Image *mean_image = static_cast<cl::Image *>(mean_image_->buffer());
kernel_step1_.setArg(idx++, *mean_image);
cl::Image *square_image = static_cast<cl::Image *>(square_image_->buffer());
kernel_step1_.setArg(idx++, *square_image);
kernel_step1_.setArg(idx++, static_cast<int>(height));
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step1_);
std::string
tuning_key = Concat("mvnorm_v_step1_opencl_kernel", gws[0], gws[1],
gws[2], normalize_variance_,
across_channels_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step1_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
// The second step of compute Variance Norm, read the (X - EX)^2 from
// square_image_ and compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
MaceStatus MVNormKernel::ExecuteVarianceNormStep2Kernel(
OpContext *context, OpenCLRuntime *runtime, const uint32_t (&gws)[3],
const Tensor *input, Tensor *output) {
const auto height = input->dim(1);
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_step2_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step2_,
"mvnorm_vn_step2",
&built_options, across_channels_));
kwg_size_step2_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_step2_));
}
MACE_OUT_OF_RANGE_INIT(kernel_step2_);
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_step2_);
MACE_SET_3D_GWS_ARGS(kernel_step2_, gws);
kernel_step2_.setArg(idx++, *(input->opencl_image()));
cl::Image *mean_image = static_cast<cl::Image *>(mean_image_->buffer());
kernel_step2_.setArg(idx++, *mean_image);
cl::Image *square_image = static_cast<cl::Image *>(square_image_->buffer());
kernel_step2_.setArg(idx++, *square_image);
kernel_step2_.setArg(idx++, static_cast<int>(height));
kernel_step2_.setArg(idx++, static_cast<float>(eps_));
kernel_step2_.setArg(idx++, *(output->opencl_image()));
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step2_);
std::string
tuning_key = Concat("mvnorm_v_step2_opencl_kernel", gws[0], gws[1],
gws[2], normalize_variance_,
across_channels_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step2_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_IMAGE_MVNORM_H_
#define MACE_OPS_OPENCL_IMAGE_MVNORM_H_
#include <memory>
#include <vector>
#include "mace/core/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/tensor.h"
#include "mace/ops/opencl/helper.h"
#include "mace/ops/opencl/mvnorm.h"
namespace mace {
namespace ops {
namespace opencl {
namespace image {
class MVNormKernel : public OpenCLMVNormKernel {
public:
explicit MVNormKernel(bool normalize_variance_,
bool across_channels, float eps);
~MVNormKernel() = default;
MaceStatus Compute(
OpContext *context, const Tensor *input, Tensor *output) override;
private:
void CheckImage(OpContext *context, const DataType dt,
const std::vector<index_t> &square_shape,
const std::vector<index_t> &mean_shape);
MaceStatus ExecuteMeanNormKernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input,
Tensor *output);
MaceStatus ExecuteVarianceNormStep1Kernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input);
MaceStatus ExecuteVarianceNormStep2Kernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input,
Tensor *output);
private:
bool normalize_variance_;
bool across_channels_;
float eps_;
cl::Kernel kernel_step1_;
uint32_t kwg_size_step1_;
cl::Kernel kernel_step2_;
uint32_t kwg_size_step2_;
// the cache of (X - EX)^2
std::unique_ptr<Image> square_image_;
// the cache of EX
std::unique_ptr<Image> mean_image_;
};
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_IMAGE_MVNORM_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_LPNORM_H_
#define MACE_OPS_OPENCL_LPNORM_H_
#include "mace/public/mace.h"
#include "mace/utils/math.h"
namespace mace {
class OpContext;
class Tensor;
namespace ops {
class OpenCLLpNormKernel {
public:
virtual MaceStatus Compute(
OpContext *context,
const Tensor *input,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLLpNormKernel);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_LPNORM_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_MVNORM_H_
#define MACE_OPS_OPENCL_MVNORM_H_
#include "mace/public/mace.h"
#include "mace/utils/math.h"
namespace mace {
class OpContext;
class Tensor;
namespace ops {
class OpenCLMVNormKernel {
public:
virtual MaceStatus Compute(
OpContext *context,
const Tensor *input,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLMVNormKernel);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_MVNORM_H_
......@@ -56,12 +56,13 @@ void BiasAdd::AddBias(const OpContext *context,
const index_t width = output->dim(3);
const index_t image_size = height * width;
auto bias_b = bias->dim_size() == 1 ? 0 : bias->shape()[1];
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t offset = (b * channels + c) * image_size;
auto input_ptr = input_data + offset;
auto output_ptr = output_data + offset;
const float bias = bias_data[c];
const float bias = bias_data[bias_b * channels + c];
for (index_t i = 0; i < image_size; ++i) {
(*output_ptr++) = (*input_ptr++) + bias;
......
......@@ -46,8 +46,10 @@ extern void RegisterDelay(OpRegistryBase *op_registry);
extern void RegisterInferConv2dShape(OpRegistryBase *op_registry);
extern void RegisterKaldiBatchNorm(OpRegistryBase *op_registry);
extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry);
extern void RegisterLpNorm(OpRegistryBase *op_registry);
extern void RegisterLSTMNonlinear(OpRegistryBase *op_registry);
extern void RegisterMatMul(OpRegistryBase *op_registry);
extern void RegisterMVNorm(OpRegistryBase *op_registry);
extern void RegisterOneHot(OpRegistryBase *op_registry);
extern void RegisterPad(OpRegistryBase *op_registry);
extern void RegisterPadContext(OpRegistryBase *op_registry);
......@@ -121,8 +123,10 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterInferConv2dShape(this);
ops::RegisterKaldiBatchNorm(this);
ops::RegisterLocalResponseNorm(this);
ops::RegisterLpNorm(this);
ops::RegisterLSTMNonlinear(this);
ops::RegisterMatMul(this);
ops::RegisterMVNorm(this);
ops::RegisterOneHot(this);
ops::RegisterPad(this);
ops::RegisterPadContext(this);
......
......@@ -159,8 +159,10 @@ void RegisterReshape(OpRegistryBase *op_registry) {
auto tensor_shape_info = context->tensor_shape_info();
const std::string &input_0 = op->input(0);
if (4 == op->output_shape(0).dims_size() &&
4 == tensor_shape_info->at(input_0).size()) {
const auto out_dims_size =
op->output_shape(0).dims_size();
if (4 == tensor_shape_info->at(input_0).size()
&& (out_dims_size == 4 || out_dims_size == 2)) {
return {DeviceType::CPU, DeviceType::GPU};
}
return {DeviceType::CPU};
......
......@@ -82,12 +82,13 @@ class SoftmaxOp<DeviceType::CPU, float> : public Operation {
index_t batch_stride = class_size;
index_t batch_size = batch_stride * input->dim(0);
Buffer cache_buffer(context->device()->allocator());
MACE_RETURN_IF_ERROR(cache_buffer.Allocate(hw_size * sizeof(float)));
auto cache_buffer = context->device()->scratch_buffer();
cache_buffer->Rewind();
MACE_RETURN_IF_ERROR(cache_buffer->GrowSize(hw_size * sizeof(float)));
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
float std_lowest = std::numeric_limits<float>::lowest();
float *cache_ptr = cache_buffer.mutable_data<float>();
float *cache_ptr = cache_buffer->mutable_data<float>();
for (index_t b_offset = 0;
b_offset < batch_size; b_offset += batch_stride) {
......
......@@ -84,7 +84,7 @@ def encrypt_opencl_codegen(cl_kernel_dir, output_path):
for file_name in os.listdir(cl_kernel_dir):
file_path = os.path.join(cl_kernel_dir, file_name)
module_key = get_module_key(file_name)
if len(module_key) > 0:
if module_key is not None and len(module_key) > 0:
with open(file_path, "r") as f:
code_str = ""
headers = []
......
......@@ -63,6 +63,8 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/space_to_depth.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/sqrdiff_mean.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/winograd_transform.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/lpnorm.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/mvnorm.cl"))
python_bin_path = repository_ctx.which("python")
......
......@@ -62,11 +62,60 @@ void BiasAddSimple() {
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
template <DeviceType D>
void BiasAddSimple2D() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", {2, 3, 1, 2},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
net.AddInputFromArray<D, float>("Bias", {2, 2},
{0.1f, 0.2f, 0.3f, 0.4f}, true);
if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW")
.Input("Bias")
.AddIntArg("has_data_format", 1)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) {
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else {
MACE_NOT_IMPLEMENTED;
}
// Check
auto expected = net.CreateTensor<float>(
{2, 3, 1, 2},
{5.1, 5.2, 7.1, 7.2, 9.1, 9.2, 11.3, 11.4, 13.3, 13.4, 15.3, 15.4});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace
TEST_F(BiasAddOpTest, BiasAddSimpleCPU) { BiasAddSimple<DeviceType::CPU>(); }
TEST_F(BiasAddOpTest, BiasAddSimpleOPENCL) { BiasAddSimple<DeviceType::GPU>(); }
TEST_F(BiasAddOpTest, BiasAddSimple2DCPU) {
BiasAddSimple2D<DeviceType::CPU>();
}
TEST_F(BiasAddOpTest, BiasAddSimpleOPENCL) {
BiasAddSimple<DeviceType::GPU>();
}
TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
// generate random input
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class LpNormOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestLpNorm(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const int p,
const int axis,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<D, T>(MakeString("Input"), input_shape, input);
if (D == DeviceType::GPU) {
net.TransformDataFormat<GPU, float>(
"Input", DataFormat::NCHW, "InputNHWC", DataFormat::NHWC);
}
OpDefBuilder("LpNorm", "LpNormTest")
.Input(D == DeviceType::CPU ? "Input" : "InputNHWC")
.AddIntArg("p", p)
.AddIntArg("axis", axis)
.Output(D == DeviceType::CPU ? "Output" : "OutputNHWC")
.Finalize(net.NewOperatorDef());
net.RunOp(D);
if (D == DeviceType::GPU) {
net.TransformDataFormat<GPU, float>(
"OutputNHWC", DataFormat::NHWC, "Output", DataFormat::NCHW);
}
net.AddInputFromArray<D, T>("ExpectedOutput", input_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(LpNormOpTest, SimpleTestFabs) {
TestLpNorm<DeviceType::CPU, float>(
{1, 8, 1, 2}, // NCHW
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
1, 1,
{0.00735294, 0.0147059, 0.0220588, 0.0294118,
0.0367647, 0.0441176, 0.0514706, 0.0588235,
0.0661765, 0.0735294, 0.0808824, 0.0882353,
0.0955882, 0.102941, 0.110294, 0.117647});
}
TEST_F(LpNormOpTest, SimpleTestSquare) {
TestLpNorm<DeviceType::CPU, float>(
{1, 8, 1, 2}, // NCHW
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
2, 1,
{0.0258544, 0.0517088, 0.0775632, 0.103418,
0.129272, 0.155126, 0.180981, 0.206835,
0.232689, 0.258544, 0.284398, 0.310253,
0.336107, 0.361961, 0.387816, 0.41367});
}
TEST_F(LpNormOpTest, SimpleTestPSquare2) {
TestLpNorm<DeviceType::CPU, float>(
{1, 8, 1, 2}, // NCHW
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
2, 2,
{0.447214, 0.894427, 0.600000, 0.800000,
0.640184, 0.768221, 0.658505, 0.752577,
0.668965, 0.743294, 0.675725, 0.737154,
0.680451, 0.732793, 0.683941, 0.729537});
}
TEST_F(LpNormOpTest, SimpleTestFabsOpenCL) {
TestLpNorm<DeviceType::GPU, float>(
{1, 8, 1, 2}, // NCHW
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
1, 1,
{0.00735294, 0.0147059, 0.0220588, 0.0294118,
0.0367647, 0.0441176, 0.0514706, 0.0588235,
0.0661765, 0.0735294, 0.0808824, 0.0882353,
0.0955882, 0.102941, 0.110294, 0.117647});
}
TEST_F(LpNormOpTest, SimpleTestSquareOpenCL) {
TestLpNorm<DeviceType::GPU, float>(
{1, 8, 1, 2}, // NCHW
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
2, 1,
{0.0258544, 0.0517088, 0.0775632, 0.103418,
0.129272, 0.155126, 0.180981, 0.206835,
0.232689, 0.258544, 0.284398, 0.310253,
0.336107, 0.361961, 0.387816, 0.41367});
}
TEST_F(LpNormOpTest, SimpleTestSquareOpenCL2) {
TestLpNorm<DeviceType::GPU, float>(
{1, 8, 1, 2}, // NCHW
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
2, 2,
{0.447214, 0.894427, 0.600000, 0.800000,
0.640184, 0.768221, 0.658505, 0.752577,
0.668965, 0.743294, 0.675725, 0.737154,
0.680451, 0.732793, 0.683941, 0.729537});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class MVNormOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestMVNorm(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
bool normalize_variance,
bool across_channels,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<D, T>(MakeString("Input"), input_shape, input);
if (D == DeviceType::CPU) {
net.TransformDataFormat<CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
}
OpDefBuilder("MVNorm", "MVNormTest")
.Input(D == DeviceType::CPU ? "InputNCHW" : "Input")
.AddIntArg("normalize_variance", normalize_variance)
.AddIntArg("across_channels", across_channels)
.Output(D == DeviceType::CPU ? "OutputNCHW" : "Output")
.Finalize(net.NewOperatorDef());
net.RunOp(D);
if (D == DeviceType::CPU) {
net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
}
net.AddInputFromArray<D, T>("ExpectedOutput", input_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(MVNormOpTest, SimpleTestMean) {
TestMVNorm<DeviceType::CPU, float>(
{1, 1, 5, 12},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
false, true,
{-9.5, -8.5, -7.5, -6.5, -5.5,
-4.5, -3.5, -2.5, -1.5, -0.5,
0.5, 1.5, -7.5, -6.5, -5.5,
-4.5, -3.5, -2.5, -1.5, -0.5,
0.5, 1.5, 2.5, 3.5, -5.5,
-4.5, -3.5, -2.5, -1.5, -0.5,
0.5, 1.5, 2.5, 3.5, 4.5,
5.5, -3.5, -2.5, -1.5, -0.5,
0.5, 1.5, 2.5, 3.5, 4.5,
5.5, 6.5, 7.5, -1.5, -0.5,
0.5, 1.5, 2.5, 3.5, 4.5,
5.5, 6.5, 7.5, 8.5, 9.5});
}
TEST_F(MVNormOpTest, SimpleTestVariance) {
TestMVNorm<DeviceType::CPU, float>(
{1, 1, 5, 12},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
true, true,
{-2.1287, -1.90463, -1.68056, -1.45648, -1.23241,
-1.00833, -0.784259, -0.560185, -0.336111, -0.112037,
0.112037, 0.336111, -1.68056, -1.45648, -1.23241,
-1.00833, -0.784259, -0.560185, -0.336111, -0.112037,
0.112037, 0.336111, 0.560185, 0.784259, -1.23241,
-1.00833, -0.784259, -0.560185, -0.336111, -0.112037,
0.112037, 0.336111, 0.560185, 0.784259, 1.00833,
1.23241, -0.784259, -0.560185, -0.336111, -0.112037,
0.112037, 0.336111, 0.560185, 0.784259, 1.00833,
1.23241, 1.45648, 1.68056, -0.336111, -0.112037,
0.112037, 0.336111, 0.560185, 0.784259, 1.00833,
1.23241, 1.45648, 1.68056, 1.90463, 2.1287});
}
TEST_F(MVNormOpTest, SimpleTestVariance2) {
TestMVNorm<DeviceType::CPU, float>(
{1, 1, 1, 16},
{-0.63984936, -0.5024374 , -2.1083345, 2.6399455,
-0.63989604, -0.63280314, 2.905462, 1.0263479,
-0.502281 , -0.58158046, -0.5358325 , -0.50097936,
1.2043145 , -0.53840625, -0.50652033, -0.48295242},
true, true,
{-0.485057, -0.376699, -1.643057, 2.101283,
-0.485094, -0.479501, 2.310661, 0.828852,
-0.376575, -0.439108, -0.403033, -0.375549,
0.969191, -0.405062, -0.379918, -0.361333});
}
TEST_F(MVNormOpTest, SimpleTestMeanOpenCL) {
TestMVNorm<DeviceType::GPU, float>(
{1, 1, 5, 12},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
false, true,
{-9.5, -8.5, -7.5, -6.5, -5.5,
-4.5, -3.5, -2.5, -1.5, -0.5,
0.5, 1.5, -7.5, -6.5, -5.5,
-4.5, -3.5, -2.5, -1.5, -0.5,
0.5, 1.5, 2.5, 3.5, -5.5,
-4.5, -3.5, -2.5, -1.5, -0.5,
0.5, 1.5, 2.5, 3.5, 4.5,
5.5, -3.5, -2.5, -1.5, -0.5,
0.5, 1.5, 2.5, 3.5, 4.5,
5.5, 6.5, 7.5, -1.5, -0.5,
0.5, 1.5, 2.5, 3.5, 4.5,
5.5, 6.5, 7.5, 8.5, 9.5});
}
TEST_F(MVNormOpTest, SimpleTestVarianceOpenCL) {
TestMVNorm<DeviceType::GPU, float>(
{1, 1, 5, 12},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
true, true,
{-2.1287, -1.90463, -1.68056, -1.45648, -1.23241,
-1.00833, -0.784259, -0.560185, -0.336111, -0.112037,
0.112037, 0.336111, -1.68056, -1.45648, -1.23241,
-1.00833, -0.784259, -0.560185, -0.336111, -0.112037,
0.112037, 0.336111, 0.560185, 0.784259, -1.23241,
-1.00833, -0.784259, -0.560185, -0.336111, -0.112037,
0.112037, 0.336111, 0.560185, 0.784259, 1.00833,
1.23241, -0.784259, -0.560185, -0.336111, -0.112037,
0.112037, 0.336111, 0.560185, 0.784259, 1.00833,
1.23241, 1.45648, 1.68056, -0.336111, -0.112037,
0.112037, 0.336111, 0.560185, 0.784259, 1.00833,
1.23241, 1.45648, 1.68056, 1.90463, 2.1287});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -124,7 +124,8 @@ class OpsTestNet {
input->Resize(shape);
Tensor::MappingGuard input_mapper(input);
T *input_data = input->mutable_data<T>();
MACE_CHECK(static_cast<size_t>(input->size()) == data.size());
MACE_CHECK(static_cast<size_t>(input->size()) == data.size(),
input->size(), " VS ", data.size());
memcpy(input_data, data.data(), data.size() * sizeof(T));
input->SetScale(scale);
input->SetZeroPoint(zero_point);
......
......@@ -544,6 +544,7 @@ message LayerParameter {
optional VideoDataParameter video_data_param = 207;
optional WindowDataParameter window_data_param = 129;
optional ShuffleChannelParameter shuffle_channel_param = 164;
optional L2NormalizationParameter l2normalization_param = 208;
}
// Message that stores parameters used to apply transformation
......@@ -1934,3 +1935,7 @@ message PReLUParameter {
message ShuffleChannelParameter {
optional uint32 group = 1[default = 1]; // The number of group
}
message L2NormalizationParameter {
optional int32 axis = 1 [default = 1];
}
......@@ -748,7 +748,6 @@ class DeviceWrapper:
) if YAMLKeyword.dockerfile_path \
in model_config \
else ("third_party/caffe", "lastest")
sh_commands.validate_model(
abi=target_abi,
device=self,
......
......@@ -152,6 +152,8 @@ MaceSupportedOps = [
'Transpose',
'Cumsum',
'Tile',
'LpNorm',
'MVNorm',
]
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
......@@ -169,7 +171,9 @@ MaceFixedDataFormatOps = [MaceOp.BatchNorm,
MaceOp.ResizeBilinear,
MaceOp.ResizeNearestNeighbor,
MaceOp.SpaceToBatchND,
MaceOp.SpaceToDepth]
MaceOp.SpaceToDepth,
MaceOp.LpNorm,
MaceOp.MVNorm]
MaceTransposableDataFormatOps = [MaceOp.Activation,
MaceOp.AddN,
......@@ -180,6 +184,7 @@ MaceTransposableDataFormatOps = [MaceOp.Activation,
MaceOp.Eltwise,
MaceOp.Pad,
MaceOp.Reduce,
MaceOp.Reshape,
MaceOp.Softmax,
MaceOp.Split,
MaceOp.Squeeze,
......@@ -264,6 +269,9 @@ class MaceKeyword(object):
mace_reverse_str = 'reverse'
mace_const_data_num_arg_str = 'const_data_num'
mace_coeff_str = 'coeff'
mace_p_str = 'p'
mace_nor_var_str = 'normalize_variance'
mace_across_ch_str = 'across_channels'
class TransformerRule(Enum):
......
......@@ -191,6 +191,10 @@ class CaffeConverter(base_converter.ConverterInterface):
'Flatten': self.convert_flatten,
'PriorBox': self.convert_prior_box,
'Reshape': self.convert_reshape,
'L2Normalization': self.convert_lpnorm,
'L1Normalization': self.convert_lpnorm,
'MVN': self.convert_MVN,
'Bias': self.convert_Bias,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -815,3 +819,57 @@ class CaffeConverter(base_converter.ConverterInterface):
num_axes_arg.i = -1
if param.HasField('num_axes'):
num_axes_arg.i = param.num_axes
def convert_lpnorm(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.l2normalization_param
op.type = MaceOp.LpNorm.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = -1
if param.HasField('axis'):
axis_arg.i = param.axis
p_arg = op.arg.add()
p_arg.name = MaceKeyword.mace_p_str
if caffe_op.type == 'L1Normalization':
p_arg.i = 1
elif caffe_op.type == 'L2Normalization':
p_arg.i = 2
else:
mace_check(False, "Can not support %s" % caffe_op.type)
def convert_MVN(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.mvn_param
op.type = MaceOp.MVNorm.name
if param.HasField('normalize_variance'):
nv_arg = op.arg.add()
nv_arg.name = MaceKeyword.mace_nor_var_str
nv_arg.i = param.normalize_variance
if param.HasField('across_channels'):
across_ch_arg = op.arg.add()
across_ch_arg.name = MaceKeyword.mace_across_ch_str
across_ch_arg.i = param.across_channels
if param.HasField('eps'):
eps_arg = op.arg.add()
eps_arg.name = MaceKeyword.mace_epsilon_str
eps_arg.f = param.eps
def convert_Bias(self, caffe_op):
op = self.convert_general_op(caffe_op)
op.type = MaceOp.BiasAdd.name
param = caffe_op.layer.bias_param
mace_check(not param.axis or param.axis == 0 or param.axis == 1,
"BiasAdd only support axis with 0 or 1.")
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = 1
if param.axis is not None:
mace_check(param.axis == 0 or param.axis == 1,
"BiasAdd only support axis with 0 or 1.")
axis_arg.i = param.axis
......@@ -113,7 +113,7 @@ OnnxSupportedOps = [
# 'Log',
'LogSoftmax',
# 'Loop',
# 'LpNormalization',
'LpNormalization',
# 'LpPool',
'MatMul',
'Max',
......@@ -353,6 +353,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.ImageScaler.name: self.convert_imagescaler,
OnnxOpType.LeakyRelu.name: self.convert_activation,
OnnxOpType.LogSoftmax.name: self.convert_softmax,
OnnxOpType.LpNormalization: self.convert_lpnormalization,
OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear,
OnnxOpType.DynamicLSTM.name: self.convert_dynamic_lstm,
OnnxOpType.Max.name: self.convert_eltwise,
......@@ -1435,6 +1436,18 @@ class OnnxConverter(base_converter.ConverterInterface):
use_log_arg.name = 'use_log'
use_log_arg.i = 1
def convert_lpnormalization(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.LpNorm.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = node.attrs.get('axis', -1)
p_arg = op.arg.add()
p_arg.name = MaceKeyword.mace_p_str
p_arg.i = node.attrs.get('p', 2)
def convert_splice(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Splice.name
......@@ -1565,7 +1578,6 @@ class OnnxConverter(base_converter.ConverterInterface):
op.input.append(size_tensor_name)
else:
op.type = MaceOp.ResizeBilinear.name
size_arg = op.arg.add()
size_arg.name = MaceKeyword.mace_resize_size_str
size_arg.ints.extend(output_size.tolist())
......
......@@ -52,6 +52,8 @@ class ShapeInference(object):
MaceOp.PriorBox.name: self.infer_shape_prior_box,
MaceOp.Reshape.name: self.infer_shape_reshape,
MaceOp.ResizeBilinear.name: self.infer_shape_resize_bilinear,
MaceOp.LpNorm.name: self.infer_shape_general,
MaceOp.MVNorm.name: self.infer_shape_general,
}
self._net = net
......@@ -206,7 +208,7 @@ class ShapeInference(object):
def infer_shape_slice(self, op):
output_shape = self._output_shape_cache[op.input[0]]
axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
output_shape[axis] /= len(op.output)
output_shape[axis] = (int)(output_shape[axis] / len(op.output))
output_shapes = []
for _ in op.output:
output_shapes.append(output_shape)
......
......@@ -1376,6 +1376,16 @@ class Transformer(base_converter.ConverterInterface):
out_shape.dims for out_shape in op.output_shape]))
return False
def is_transposable_data_format_ops(self, op):
if op.type == MaceOp.Reshape:
input_op = self._producer[op.input[0]]
out_dims_len = len(op.output_shape[0].dims)
if len(input_op.output_shape[0].dims) != 4 \
or (out_dims_len != 4 and out_dims_len != 2):
print("In this model, reshape is not transposable op.")
return False
return op.type in MaceTransposableDataFormatOps
def update_data_format(self):
print("update data format")
net = self._model
......@@ -1387,7 +1397,7 @@ class Transformer(base_converter.ConverterInterface):
df_arg.name = MaceKeyword.mace_data_format_str
if op.type in MaceFixedDataFormatOps:
df_arg.i = DataFormat.AUTO.value
elif op.type in MaceTransposableDataFormatOps:
elif self.is_transposable_data_format_ops(op):
input_df = DataFormat.AUTO.value
for input_tensor in op.input:
if input_tensor in self._consts:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册