Commit 8911cffb authored by liutuo

add Splice, Affine, PNorm, SumGroup ops for Kaldi models

support kaldi-onnx model conversion
support validating output data from a URL
Parent f2f05c0d
@@ -258,11 +258,12 @@ MaceStatus Gemv::Compute(const OpContext *context,
          ++rhs_ptr;
        }
+       float32x4_t vbias = vdupq_n_f32(0);
        if (bias) {
-         float32x4_t vbias = vdupq_n_f32(0);
          vbias = vld1q_f32(bias_data + h_start);
-         vo = vaddq_f32(vo, vbias);
        }
+       vo = vaddq_f32(vo, vbias);
        vst1q_f32(ret_ptr, vo);
      } else {  // h_block_len < 4
 #endif  // MACE_GEMV_UNROLL
......
@@ -376,11 +376,11 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
          ++rhs_ptr;
        }
+       int32x4_t vbias = vdupq_n_s32(0);
        if (bias) {
-         int32x4_t vbias = vdupq_n_s32(0);
          vbias = vld1q_s32(bias_data + h_offset);
-         vo = vaddq_s32(vo, vbias);
        }
+       vo = vaddq_s32(vo, vbias);
        if (is_output_type_uint8) {
          int32x4_t vo_mul = vqrdmulhq_s32(vo, voutput_multiplier);
......
@@ -25,6 +25,7 @@ namespace ops {
 template <DeviceType D, class T>
 class LSTMCellOp;

+#ifdef MACE_ENABLE_OPENCL
 template <typename T>
 class LSTMCellOp<DeviceType::GPU, T> : public Operation {
  public:
@@ -88,6 +89,7 @@ class LSTMCellOp<DeviceType::GPU, T> : public Operation {
   MACE_OP_INPUT_TAGS(INPUT, PRE_OUTPUT, WEIGHT, BIAS, PRE_CELL);
   MACE_OP_OUTPUT_TAGS(CELL, OUTPUT);
 };
+#endif

 void RegisterLSTMCell(OpRegistryBase *op_registry) {
   MACE_REGISTER_OP(op_registry, "LSTMCell", LSTMCellOp,
......
@@ -43,6 +43,7 @@ extern void RegisterInferConv2dShape(OpRegistryBase *op_registry);
 extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry);
 extern void RegisterMatMul(OpRegistryBase *op_registry);
 extern void RegisterPad(OpRegistryBase *op_registry);
+extern void RegisterPNorm(OpRegistryBase *op_registry);
 extern void RegisterPooling(OpRegistryBase *op_registry);
 extern void RegisterReduce(OpRegistryBase *op_registry);
 extern void RegisterPriorBox(OpRegistryBase *op_registry);
@@ -53,14 +54,19 @@ extern void RegisterResizeNearestNeighbor(OpRegistryBase *op_registry);
 extern void RegisterReverse(OpRegistryBase *op_registry);
 extern void RegisterScalarMath(OpRegistryBase *op_registry);
 extern void RegisterShape(OpRegistryBase *op_registry);
+extern void RegisterSlice(OpRegistryBase *op_registry);
 extern void RegisterSoftmax(OpRegistryBase *op_registry);
 extern void RegisterSpaceToBatchND(OpRegistryBase *op_registry);
 extern void RegisterSpaceToDepth(OpRegistryBase *op_registry);
+extern void RegisterSplice(OpRegistryBase *op_registry);
 extern void RegisterSplit(OpRegistryBase *op_registry);
 extern void RegisterSqrDiffMean(OpRegistryBase *op_registry);
 extern void RegisterSqueeze(OpRegistryBase *op_registry);
 extern void RegisterStack(OpRegistryBase *op_registry);
 extern void RegisterStridedSlice(OpRegistryBase *op_registry);
+extern void RegisterSumGroup(OpRegistryBase *op_registry);
+extern void RegisterTargetRMSNorm(OpRegistryBase *op_registry);
+extern void RegisterTimeOffset(OpRegistryBase *op_registry);
 extern void RegisterTranspose(OpRegistryBase *op_registry);
 extern void RegisterUnstack(OpRegistryBase *op_registry);
@@ -103,6 +109,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
   ops::RegisterLocalResponseNorm(this);
   ops::RegisterMatMul(this);
   ops::RegisterPad(this);
+  ops::RegisterPNorm(this);
   ops::RegisterPooling(this);
   ops::RegisterReduce(this);
   ops::RegisterPriorBox(this);
@@ -113,14 +120,19 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
   ops::RegisterReverse(this);
   ops::RegisterScalarMath(this);
   ops::RegisterShape(this);
+  ops::RegisterSlice(this);
   ops::RegisterSoftmax(this);
   ops::RegisterSpaceToBatchND(this);
   ops::RegisterSpaceToDepth(this);
+  ops::RegisterSplice(this);
   ops::RegisterSplit(this);
   ops::RegisterStack(this);
   ops::RegisterStridedSlice(this);
   ops::RegisterSqrDiffMean(this);
   ops::RegisterSqueeze(this);
+  ops::RegisterSumGroup(this);
+  ops::RegisterTargetRMSNorm(this);
+  ops::RegisterTimeOffset(this);
   ops::RegisterTranspose(this);
   ops::RegisterUnstack(this);
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for PNormComponent in Kaldi.
// The input-dim must be divisible by output-dim: each row of the
// input is split into output-dim groups of
// group = input-dim / output-dim values each.
// For each row, with j = 0 .. group - 1:
//   p == 0: output[i] = count of j where abs(input[i * group + j]) > 0
//   p == 1: output[i] = sum(abs(input[i * group + j]))
//   p == 2: output[i] = sqrt(sum(input[i * group + j]^2))
// p defaults to 2.
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>

#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class PNormOp;
template <typename T>
class PNormOp<DeviceType::CPU, T> : public Operation {
public:
explicit PNormOp(OpConstructContext *context)
: Operation(context),
p_(Operation::GetOptionalArg<int>("p", 2)),
output_dim_(Operation::GetOptionalArg<int>("output_dim", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const std::vector<index_t> &input_shape = input->shape();
const index_t dim_size = input_shape.size();
MACE_CHECK(dim_size >= 1, "PNorm only supports input dim size >= 1");
std::vector<index_t> output_shape(input_shape);
    const index_t input_dim = input_shape[dim_size - 1];
MACE_CHECK(output_dim_ > 0,
"Output dim should be greater than zero.");
MACE_CHECK(input_dim % output_dim_ == 0 && output_dim_ < input_dim,
"PNorm's input dim should be a multiple of output dim.");
const index_t group_size = input_dim / output_dim_;
output_shape[dim_size - 1] = output_dim_;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
const index_t bh =
std::accumulate(input->shape().begin(), input->shape().end() - 1, 1,
std::multiplies<index_t>());
if (p_ == 0) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
for (index_t j = 0; j < output_dim_; ++j) {
const T *in_base = input_data + i * input_dim + j * group_size;
T *out_base = output_data + i * output_dim_;
T temp_result = 0;
for (index_t g = 0; g < group_size; ++g) {
T value =
(std::fabs(in_base[g])
> std::numeric_limits<float>::epsilon()) ? 1.0f : 0.0f;
temp_result += value;
}
out_base[j] = temp_result;
}
}
} else if (p_ == 1) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
for (index_t j = 0; j < output_dim_; ++j) {
const T *in_base = input_data + i * input_dim + j * group_size;
T *out_base = output_data + i * output_dim_;
T temp_result = 0;
for (index_t g = 0; g < group_size; ++g) {
            temp_result += std::abs(in_base[g]);
}
out_base[j] = temp_result;
}
}
} else if (p_ == 2) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
for (index_t j = 0; j < output_dim_; ++j) {
const T *in_base = input_data + i * input_dim + j * group_size;
T *out_base = output_data + i * output_dim_;
T temp_result = 0;
for (index_t g = 0; g < group_size; ++g) {
temp_result += in_base[g] * in_base[g];
}
out_base[j] = std::sqrt(temp_result);
}
}
} else {
LOG(FATAL) << "PNorm's p should be 0, 1 or 2, here p is: " << p_;
}
return MaceStatus::MACE_SUCCESS;
}
private:
int p_;
int output_dim_;
};
void RegisterPNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "PNorm", PNormOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
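
A minimal NumPy sketch of the PNorm semantics above, usable as a cross-check against the CPU kernel (the helper name pnorm_ref is hypothetical, not part of this commit):

import numpy as np

def pnorm_ref(x, p, output_dim):
    # x: (..., input_dim); input_dim must be a multiple of output_dim.
    input_dim = x.shape[-1]
    assert input_dim % output_dim == 0 and output_dim < input_dim
    groups = x.reshape(x.shape[:-1] + (output_dim, input_dim // output_dim))
    if p == 0:
        # Count entries whose magnitude exceeds float epsilon.
        return (np.abs(groups) > np.finfo(np.float32).eps).sum(-1).astype(x.dtype)
    if p == 1:
        return np.abs(groups).sum(-1)
    if p == 2:
        return np.sqrt((groups * groups).sum(-1))
    raise ValueError("p should be 0, 1 or 2")

# Reproduces the first row of the SimpleTest expectation in the test below:
x = np.arange(1, 11, dtype=np.float32).reshape(1, 1, 10)
print(pnorm_ref(x, 2, 5))  # [[[2.236 5. 7.810 10.630 13.454]]]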
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void PNormBenchmark(int iters, int n, int h, int w, int p, int ow) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {n, h, w});
OpDefBuilder("PNorm", "PNormBM")
.Input("Input")
.AddIntArg("p", p)
.AddIntArg("output_dim", ow)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
} // namespace
#define MACE_BM_PNORM_MACRO(N, H, W, P, OW, TYPE, DEVICE) \
static void \
MACE_BM_PNORM_##N##_##H##_##W##_##P##_##OW##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
PNormBenchmark<DEVICE, TYPE>(iters, N, H, W, P, OW); \
} \
MACE_BENCHMARK( \
MACE_BM_PNORM_##N##_##H##_##W##_##P##_##OW##_##TYPE##_##DEVICE)
#define MACE_BM_PNORM(N, H, W, P, OW) \
MACE_BM_PNORM_MACRO(N, H, W, P, OW, float, CPU);
MACE_BM_PNORM(1, 10, 256, 0, 128);
MACE_BM_PNORM(1, 20, 128, 1, 64);
MACE_BM_PNORM(1, 10, 128, 2, 64);
MACE_BM_PNORM(1, 16, 256, 0, 128);
MACE_BM_PNORM(1, 32, 128, 1, 64);
MACE_BM_PNORM(1, 10, 512, 2, 256);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class PNormOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestPNorm(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const int p,
const int output_dim,
const std::vector<index_t> &output_shape,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("PNorm", "PNormTest")
.Input("Input")
.AddIntArg("p", p)
.AddIntArg("output_dim", output_dim)
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", output_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(PNormOpTest, SimpleTest) {
TestPNorm<DeviceType::CPU, float>(
{1, 5, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18},
2, 5,
{1, 5, 5},
{2.236067977, 5, 7.810249676, 10.630145813, 13.453624047,
5, 7.810249676, 10.630145813, 13.453624047, 16.278820596,
7.810249676, 10.630145813, 13.453624047, 16.278820596, 19.104973175,
10.630145813, 13.453624047, 16.278820596, 19.104973175, 21.931712199,
13.453624047, 16.278820596, 19.104973175, 21.931712199, 24.758836806});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>

#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class SliceOp;
template <typename T>
class SliceOp<DeviceType::CPU, T> : public Operation {
public:
explicit SliceOp(OpConstructContext *context)
: Operation(context),
axes_(Operation::GetRepeatedArgs<int>("axes")),
starts_(Operation::GetRepeatedArgs<int>("starts")),
ends_(Operation::GetRepeatedArgs<int>("ends")) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const index_t rank = input->dim_size();
    MACE_CHECK(rank >= 1)
        << "The input's rank should be >= 1.";
MACE_CHECK(starts_.size() == 1 && ends_.size() == 1 && axes_.size() == 1,
"only support slicing at one axis.");
MACE_CHECK(axes_[0] == -1 || axes_[0] == rank - 1,
"only support slicing at the last axis.");
const index_t input_dim = input->dim(rank - 1);
const index_t offset = starts_[0];
const index_t output_dim = ends_[0] - starts_[0];
    MACE_CHECK(output_dim >= 0, "output_dim should be >= 0");
    MACE_CHECK(starts_[0] < input_dim
                   && output_dim <= input_dim
                   && ends_[0] <= input_dim)
        << "The given starts and ends are out of range.";
const index_t frames =
std::accumulate(input->shape().begin(), input->shape().end() - 1, 1,
std::multiplies<index_t>());
std::vector<index_t> output_shape = input->shape();
output_shape[rank - 1] = output_dim;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < frames; ++i) {
const T *input_base =
input_data + i * input_dim + offset;
T *output_base =
output_data + i * output_dim;
memcpy(output_base, input_base, output_dim * sizeof(T));
}
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<int> axes_;
std::vector<int> starts_;
std::vector<int> ends_;
};
void RegisterSlice(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Slice", SliceOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
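
Since the op only supports slicing along the last axis, its behavior reduces to basic NumPy indexing; a small sketch of that equivalence (slice_ref is illustrative only):

import numpy as np

def slice_ref(x, starts, ends, axes):
    # Mirrors the CPU SliceOp: exactly one slice axis, and it is the last one.
    assert len(starts) == len(ends) == len(axes) == 1
    assert axes[0] in (-1, x.ndim - 1)
    return x[..., starts[0]:ends[0]]

x = np.arange(1, 16, dtype=np.float32).reshape(3, 5)
print(slice_ref(x, [2], [5], [-1]))  # matches the Simple2Dim test below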
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class SliceOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const int offset,
const int output_dim,
const std::vector<index_t> &output_shape,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("Slice", "SliceTest")
.Input("Input")
.Output("Output")
.AddIntsArg("axes", {-1})
.AddIntsArg("starts", {offset})
.AddIntsArg("ends", {offset + output_dim})
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", output_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(SliceOpTest, Simple2Dim) {
TestSlice<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
2, 3, {3, 3},
{3, 4, 5, 8, 9, 10, 13, 14, 15});
}
TEST_F(SliceOpTest, Simple3Dim) {
TestSlice<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
1, 2, {2, 3, 2},
{2, 3, 7, 8, 12, 13, 2, 3, 7, 8, 12, 13});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for SpliceComponent in Kaldi.
// It splices a context window of frames together over time
// (copies and appends the frames whose time-indexes are in context_).
// The context_ values indicate which frames (over time) to splice:
// if a context value points before the first frame,
// the first frame's data is copied and appended;
// if it points past the last frame,
// the last frame's data is copied and appended.
// E.g., given input data [[1, 2, 3], [4, 5, 6]],
// with input-dim = 3, frame count = 2, context = [-1, 0, 1],
// the output is:
// [1, 2, 3, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 4, 5, 6].
// If const_component_dim_ != 0, const_dim_ determines which row of "in"
// the last part of each row of "out" is copied from (this part is not
// subject to splicing; it is assumed constant for each frame of "input").
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>

#include "mace/core/operator.h"
#include "mace/utils/utils.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class SpliceOp;
template <typename T>
class SpliceOp<DeviceType::CPU, T> : public Operation {
public:
explicit SpliceOp(OpConstructContext *context)
: Operation(context),
context_(Operation::GetRepeatedArgs<int>("context")),
const_dim_(
Operation::GetOptionalArg<int>("const_component_dim", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
MACE_CHECK(context_.size() > 0)
<< "The context param should not be empty in Splice Op.";
Tensor *output = this->Output(0);
const std::vector<index_t> &input_shape = input->shape();
const index_t frames =
std::accumulate(input->shape().begin(), input->shape().end() - 1, 1,
std::multiplies<index_t>());
const index_t rank = input->dim_size();
const index_t input_dim = input_shape[rank - 1];
const index_t num_splice = static_cast<index_t>(context_.size());
const index_t dim = input_dim - const_dim_;
MACE_CHECK(input_dim > const_dim_,
"input dim should be greater than const dim.");
const index_t output_dim = dim * num_splice + const_dim_;
std::vector<index_t> output_shape = input->shape();
output_shape[rank - 1] = output_dim;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < frames; ++i) {
for (index_t c = 0; c < num_splice; ++c) {
const index_t offset =
Clamp<index_t>(context_[c] + i, 0, frames - 1);
T *output_base = output_data + i * output_dim + c * dim;
const T *input_base = input_data + offset * input_dim;
memcpy(output_base, input_base, dim * sizeof(T));
}
}
if (const_dim_ > 0) {
const index_t output_offset = output_dim - const_dim_;
const index_t input_offset = dim;
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < frames; ++i) {
        const index_t offset =
            Clamp<index_t>(i + context_[0], 0, frames - 1);
T *output_base = output_data + i * output_dim;
const T *input_base = input_data + offset * input_dim;
memcpy(output_base + output_offset,
input_base + input_offset,
const_dim_ * sizeof(T));
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<int> context_;
int const_dim_;
};
void RegisterSplice(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Splice", SpliceOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
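
For reference, a NumPy sketch of the splice logic above, including the boundary clamping and the const_component_dim tail (splice_ref is hypothetical):

import numpy as np

def splice_ref(x, context, const_dim=0):
    # x: (frames, input_dim); context lists the time offsets to splice.
    frames, input_dim = x.shape
    dim = input_dim - const_dim
    out = np.empty((frames, dim * len(context) + const_dim), dtype=x.dtype)
    for i in range(frames):
        for c, t in enumerate(context):
            src = min(max(i + t, 0), frames - 1)  # clamp to valid frames
            out[i, c * dim:(c + 1) * dim] = x[src, :dim]
        if const_dim > 0:
            src = min(max(i + context[0], 0), frames - 1)
            out[i, -const_dim:] = x[src, dim:]
    return out

x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
print(splice_ref(x, [-1, 0, 1]))  # the example from the header comment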
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template<DeviceType D, typename T>
void BMSpliceHelper(int iters,
const std::vector<index_t> &input_shape,
const index_t left_context,
const index_t right_context,
const int const_component_dim) {
mace::testing::StopTiming();
// Construct graph
OpsTestNet net;
const int num_splice = left_context + right_context + 1;
std::vector<int> contexts(num_splice);
  for (int i = 0; i < num_splice; ++i) {
    contexts[i] = -left_context + i;  // covers [-left_context, right_context]
  }
const index_t input_size = std::accumulate(input_shape.begin(),
input_shape.end(),
1,
std::multiplies<index_t>());
std::vector<float> input_data(input_size);
GenerateRandomRealTypeData(input_shape, &input_data);
net.AddInputFromArray<D, float>("Input", input_shape, input_data);
OpDefBuilder("Splice", "SpliceTest")
.Input("Input")
.Output("Output")
.AddIntsArg("context", contexts)
.AddIntArg("const_component_dim", const_component_dim)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
net.Sync();
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
} // namespace
#define MACE_BM_SPLICE_MACRO(N, H, W, L, R, C, TYPE, DEVICE) \
static void \
MACE_BM_SPLICE_##N##_##H##_##W##_##L##_##R##_##C##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W; \
mace::testing::MacsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
BMSpliceHelper<DEVICE, TYPE>(iters, {N, H, W}, L, R, C); \
} \
MACE_BENCHMARK( \
MACE_BM_SPLICE_##N##_##H##_##W##_##L##_##R##_##C##_##TYPE##_##DEVICE)
#define MACE_BM_SPLICE(N, H, W, L, R, C) \
MACE_BM_SPLICE_MACRO(N, H, W, L, R, C, float, CPU);
MACE_BM_SPLICE(1, 32, 32, 5, 5, 10);
MACE_BM_SPLICE(1, 32, 32, 7, 7, 5);
MACE_BM_SPLICE(1, 32, 32, 3, 3, 20);
MACE_BM_SPLICE(1, 128, 128, 9, 9, 100);
MACE_BM_SPLICE(1, 128, 128, 7, 7, 100);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class SpliceOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestSplice(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const std::vector<int> &context,
const int const_dim,
const std::vector<index_t> &output_shape,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("Splice", "SpliceTest")
.Input("Input")
.Output("Output")
.AddIntsArg("context", context)
.AddIntArg("const_component_dim", const_dim)
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", output_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(SpliceOpTest, WithoutConstDim) {
TestSplice<DeviceType::CPU, float>(
{1, 7, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
{-2, -1, 0, 1, 2}, 0,
{1, 7, 10},
{1, 2, 1, 2, 1, 2, 3, 4, 5, 6,
1, 2, 1, 2, 3, 4, 5, 6, 7, 8,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
7, 8, 9, 10, 11, 12, 13, 14, 13, 14,
9, 10, 11, 12, 13, 14, 13, 14, 13, 14});
}
TEST_F(SpliceOpTest, WithConstDim) {
TestSplice<DeviceType::CPU, float>(
{1, 5, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
{-2, -1, 0, 1, 2}, 7,
{1, 5, 22},
{1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10,
1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 4, 5, 6, 7, 8, 9, 10,
1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10,
2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 8, 9, 10, 11,
3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 6, 7, 8, 9, 10, 11, 12});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for SumGroupComponent in Kaldi.
// It's used to sum up groups of posteriors,
// and to introduce a kind of Gaussian-mixture-model-like
// idea into neural nets.
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class SumGroupOp;
template <typename T>
class SumGroupOp<DeviceType::CPU, T> : public Operation {
public:
explicit SumGroupOp(OpConstructContext *context)
: Operation(context) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
MACE_CHECK(this->InputSize() >= 2,
"SumGroup should have at least 2 inputs.");
const Tensor *input = this->Input(0);
    // The second input ("sizes") is a vector giving, for each
    // output dimension, how many consecutive input values
    // are summed over.
const Tensor *sizes = this->Input(1);
Tensor *output = this->Output(0);
MACE_CHECK(input->dim_size() >= 1,
"SumGroup's input's rank should be >= 1.");
MACE_CHECK(sizes->dim_size() == 1,
"SumGroup's sizes input should be a vector.");
const std::vector<index_t> &input_shape = input->shape();
const index_t bh =
std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
std::multiplies<index_t>());
std::vector<index_t> output_shape(input_shape);
const index_t output_dim = sizes->dim(0);
const index_t dim_size = input->dim_size();
    const index_t input_dim = input_shape[dim_size - 1];
output_shape[dim_size - 1] = output_dim;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_sizes(sizes);
Tensor::MappingGuard guard_output(output);
const T *input_data = input->data<T>();
const int *sizes_data = sizes->data<int>();
T *output_data = output->mutable_data<T>();
    std::vector<std::pair<int, int>>
        sum_indexes(static_cast<size_t>(output_dim));
int cur_index = 0;
for (index_t i = 0; i < output_dim; ++i) {
int size_value = sizes_data[i];
MACE_CHECK(size_value > 0, "size value should be > 0");
sum_indexes[i].first = cur_index;
cur_index += size_value;
sum_indexes[i].second = cur_index;
      MACE_CHECK(cur_index <= input_dim)
          << "size values accumulate to " << cur_index
          << ", which exceeds input dim " << input_dim;
}
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
for (index_t j = 0; j < output_dim; ++j) {
int start_col = sum_indexes[j].first;
int end_col = sum_indexes[j].second;
T sum = 0;
for (int src_col = start_col; src_col < end_col; ++src_col) {
sum += input_data[i * input_dim + src_col];
}
output_data[i * output_dim + j] = sum;
}
}
return MaceStatus::MACE_SUCCESS;
}
};
void RegisterSumGroup(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "SumGroup", SumGroupOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
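
A NumPy sketch of the SumGroup semantics, handy for cross-checking (sum_group_ref is illustrative):

import numpy as np

def sum_group_ref(x, sizes):
    # sizes[j] = how many consecutive input values feed output column j.
    assert sum(sizes) <= x.shape[-1]
    bounds = np.cumsum([0] + list(sizes))
    cols = [x[..., bounds[j]:bounds[j + 1]].sum(-1) for j in range(len(sizes))]
    return np.stack(cols, axis=-1)

x = np.arange(1, 11, dtype=np.float32).reshape(1, 10)
print(sum_group_ref(x, [2, 1, 2, 3, 2]))  # [[3. 3. 9. 21. 19.]]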
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void SumGroupBenchmark(int iters, int n, int h, int w) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {n, h, w});
net.AddRepeatedInput<D, int>("Sizes",
{w / 2},
2);
OpDefBuilder("SumGroup", "SumGroupBM")
.Input("Input")
.Input("Sizes")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
} // namespace
#define MACE_BM_SUMGROUP_MACRO(N, H, W, TYPE, DEVICE) \
static void \
MACE_BM_SUMGROUP_##N##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
SumGroupBenchmark<DEVICE, TYPE>(iters, N, H, W); \
} \
MACE_BENCHMARK( \
MACE_BM_SUMGROUP_##N##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_SUMGROUP(N, H, W) \
MACE_BM_SUMGROUP_MACRO(N, H, W, float, CPU);
MACE_BM_SUMGROUP(1, 10, 256);
MACE_BM_SUMGROUP(1, 20, 128);
MACE_BM_SUMGROUP(1, 10, 128);
MACE_BM_SUMGROUP(1, 20, 512);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class SumGroupOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestSumGroup(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const std::vector<int> &sizes,
const std::vector<index_t> &output_shape,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
const index_t output_dim = sizes.size();
net.AddInputFromArray<CPU, int>(MakeString("Sizes"),
{output_dim},
sizes);
OpDefBuilder("SumGroup", "SumGroupTest")
.Input("Input")
.Input("Sizes")
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", output_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(SumGroupOpTest, SimpleTest) {
TestSumGroup<DeviceType::CPU, float>(
{1, 5, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
{2, 1, 2, 3, 2},
{1, 5, 5},
{3, 3, 9, 21, 19,
5, 4, 11, 24, 21,
7, 5, 13, 27, 23,
9, 6, 15, 30, 25,
11, 7, 17, 33, 27});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This op is implemented for Kaldi's NormalizeComponent.
// The output is y_i = scale * x_i, and we want the RMS value of y
// to equal target_rms, i.e. y^T y = D * target_rms^2
// (y being one row of the output, D the length of a row).
// Hence scale = 1.0 / sqrt(x^T x / (D * target_rms^2)).
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>

#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class TargetRMSNormOp;
template <typename T>
class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
public:
explicit TargetRMSNormOp(OpConstructContext *context)
: Operation(context),
target_rms_(Operation::GetOptionalArg<float>("target_rms", 1.0)) {}
  // Calculate the square sum of an array, accumulating into four
  // partial sums to shorten the floating-point dependency chain.
float SquareSum(const float *data, const index_t data_len) {
const int num_parts = 4;
float result = 0.0f;
if (data_len <= 2 * num_parts) {
for (index_t i = 0; i < data_len; ++i) {
result += data[i] * data[i];
}
} else {
const index_t part_len = data_len / num_parts;
const index_t left_len = data_len % num_parts;
float results[4] = {0.f, 0.f, 0.f, 0.f};
for (index_t i = 0; i < num_parts; ++i) {
for (index_t j = 0; j < part_len; ++j) {
results[i] += data[i * part_len + j] * data[i * part_len + j];
}
}
for (index_t k = 0; k < left_len; ++k) {
float d = data[num_parts * part_len + k];
results[3] += d * d;
}
for (index_t i = 0; i < num_parts; ++i) {
result += results[i];
}
}
return result;
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const std::vector<index_t> &input_shape = input->shape();
const index_t dim_size = input->dim_size();
MACE_CHECK(dim_size >= 1,
"TargetRMSNorm's input dim size should be >= 1.");
    const index_t dim = input_shape[dim_size - 1];
MACE_CHECK(dim > 0 && target_rms_ > 0,
"Both input dim and target rms should be greater than zero.");
const index_t bh =
std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
std::multiplies<index_t>());
const float d_scale = dim * target_rms_ * target_rms_;
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
float scale = SquareSum(input_data + i * dim, dim);
scale = static_cast<float>(1.0 / std::sqrt(scale / d_scale));
for (index_t j = 0; j < dim; ++j) {
output_data[i * dim + j] = input_data[i * dim + j] * scale;
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
float target_rms_;
};
void RegisterTargetRMSNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "TargetRMSNorm", TargetRMSNormOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
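
The scaling rule above in a few lines of NumPy (target_rms_norm_ref is illustrative):

import numpy as np

def target_rms_norm_ref(x, target_rms=1.0):
    # Scale each row so its root-mean-square value equals target_rms:
    # scale = 1 / sqrt((x . x) / (dim * target_rms^2)).
    dim = x.shape[-1]
    sq_sum = (x * x).sum(-1, keepdims=True)
    scale = 1.0 / np.sqrt(sq_sum / (dim * target_rms * target_rms))
    return x * scale

x = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
print(target_rms_norm_ref(x))  # [[0.46291 0.92582 1.38873]], as in the test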
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void TargetRMSNormBenchmark(int iters, int n, int h, int w, float target_rms) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {n, h, w});
OpDefBuilder("TargetRMSNorm", "TargetRMSNormBM")
.Input("Input")
.AddFloatArg("target_rms", target_rms)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
} // namespace
#define MACE_BM_TARGETRMSNORM_MACRO(N, H, W, RMS, TYPE, DEVICE) \
static void \
MACE_BM_TARGETRMSNORM_##N##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
TargetRMSNormBenchmark<DEVICE, TYPE>(iters, N, H, W, RMS); \
} \
MACE_BENCHMARK( \
MACE_BM_TARGETRMSNORM_##N##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_TARGETRMSNORM(N, H, W, RMS) \
MACE_BM_TARGETRMSNORM_MACRO(N, H, W, RMS, float, CPU);
MACE_BM_TARGETRMSNORM(1, 10, 256, 1.0);
MACE_BM_TARGETRMSNORM(1, 20, 128, 2.0);
MACE_BM_TARGETRMSNORM(1, 10, 128, 0.5);
MACE_BM_TARGETRMSNORM(1, 20, 512, 1.0);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class TargetRMSNormOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestTargetRMSNorm(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const float target_rms,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("TargetRMSNorm", "TargetRMSNormTest")
.Input("Input")
.AddFloatArg("target_rms", target_rms)
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", input_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(TargetRMSNormOpTest, SimpleTest) {
TestTargetRMSNorm<DeviceType::CPU, float>(
{1, 3, 3},
{1, 2, 3,
2, 3, 4,
3, 4, 5},
1.0,
{0.46291, 0.92582, 1.38873,
0.64327, 0.9649, 1.28654,
0.734847, 0.979796, 1.224745});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for the Offset descriptor in Kaldi.
// It shifts the input by a fixed offset along the time axis,
// clamping indexes at the first and last frame.
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>

#include "mace/core/operator.h"
#include "mace/utils/utils.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class TimeOffsetOp;
template <typename T>
class TimeOffsetOp<DeviceType::CPU, T> : public Operation {
public:
explicit TimeOffsetOp(OpConstructContext *context)
: Operation(context),
offset_(Operation::GetOptionalArg<int>("offset", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
index_t rank = input->dim_size();
MACE_CHECK(rank >= 2, "input's rank should >= 2.");
const std::vector<index_t> &input_shape = input->shape();
const index_t batch =
std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
std::multiplies<index_t>());
const index_t frames = input_shape[rank - 2];
const index_t input_dim = input_shape[rank - 1];
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < batch; ++i) {
for (index_t j = 0; j < frames; ++j) {
index_t time_index = offset_ + j;
index_t index = Clamp<index_t>(time_index, 0, frames - 1);
T *output_base = output_data + (i * frames + j) * input_dim;
const T *input_base = input_data + (i * frames + index) * input_dim;
memcpy(output_base, input_base, input_dim * sizeof(T));
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
int offset_;
};
void RegisterTimeOffset(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "TimeOffset", TimeOffsetOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
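
TimeOffset reduces to a clamped gather over the frame axis; a NumPy sketch (time_offset_ref is illustrative):

import numpy as np

def time_offset_ref(x, offset):
    # x: (..., frames, dim); output frame j copies input frame
    # clamp(j + offset, 0, frames - 1).
    frames = x.shape[-2]
    idx = np.clip(np.arange(frames) + offset, 0, frames - 1)
    return x[..., idx, :]

x = np.arange(1, 16, dtype=np.float32).reshape(3, 5)
print(time_offset_ref(x, -2))  # three copies of the first row, as in the test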
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template<DeviceType D, typename T>
void TimeOffsetBenchmark(int iters,
std::vector<index_t> shape,
int offset) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", shape);
OpDefBuilder("TimeOffset", "TimeOffsetBM")
.Input("Input")
.Output("Output")
.AddIntArg("offset", offset)
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_TIMEOFFSET2D_MACRO(H, W, TYPE, DEVICE) \
static void MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE(\
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
TimeOffsetBenchmark<DEVICE, TYPE>(iters, {H, W}, 1); \
} \
  MACE_BENCHMARK(MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_TIMEOFFSET2D(H, W) \
MACE_BM_TIMEOFFSET2D_MACRO(H, W, float, CPU);
MACE_BM_TIMEOFFSET2D(20, 128);
MACE_BM_TIMEOFFSET2D(40, 512);
MACE_BM_TIMEOFFSET2D(1, 1024);
MACE_BM_TIMEOFFSET2D(20, 2048);
MACE_BM_TIMEOFFSET2D(20, 512);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class TimeOffsetOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestTimeOffset(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const int offset,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("TimeOffset", "TimeOffsetTest")
.Input("Input")
.Output("Output")
.AddIntArg("offset", offset)
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", input_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(TimeOffsetOpTest, Simple2Dim) {
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-2,
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-1,
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
0,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
1,
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
2,
{11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
}
TEST_F(TimeOffsetOpTest, Simple3Dim) {
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-2,
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5,
1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-1,
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
0,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
1,
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
2,
{11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15,
11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
}
} // namespace test
} // namespace ops
} // namespace mace
@@ -98,6 +98,7 @@ class FrameworkType(Enum):
 MaceSupportedOps = [
     'Activation',
     'AddN',
+    'Affine',
     'ArgMax',
     'BatchNorm',
     'BatchToSpaceND',
@@ -121,8 +122,10 @@ MaceSupportedOps = [
     'InferConv2dShape',
     'LocalResponseNorm',
     'LSTMCell',
+    # 'LstmNonlinear',
     'MatMul',
     'Pad',
+    'PNorm',
     'Pooling',
     'PriorBox',
     'Proposal',
@@ -134,6 +137,8 @@ MaceSupportedOps = [
     'ResizeNearestNeighbor',
     'Reverse',
     'ScalarMath',
+    'Slice',
+    'Splice',
     'Split',
     'Shape',
     'Squeeze',
@@ -144,6 +149,9 @@ MaceSupportedOps = [
     'SpaceToBatchND',
     'SpaceToDepth',
     'SqrDiffMean',
+    'SumGroup',
+    'TargetRMSNorm',
+    'TimeOffset',
     'Transpose',
     'WinogradInverseTransform',
     'WinogradTransform',
@@ -159,6 +167,7 @@ class MaceKeyword(object):
     mace_buffer_type = 'buffer_type'
     # arg related str
     mace_padding_str = 'padding'
+    mace_padding_type_str = 'padding'
     mace_padding_values_str = 'padding_values'
     mace_strides_str = 'strides'
     mace_dilations_str = 'dilations'
@@ -473,6 +482,7 @@ class ConverterOption(object):
             # Model data format related transformation
             TransformerRule.TRANSPOSE_FILTERS,
             TransformerRule.TRANSPOSE_DATA_FORMAT,
+            TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
             # Add winograd argument
             TransformerRule.ADD_WINOGRAD_ARG,
             # Mace model structure related transformation
......
@@ -15,7 +15,9 @@
 #ifndef MACE_UTILS_UTILS_H_
 #define MACE_UTILS_UTILS_H_

+#include <algorithm>
 #include <cstdlib>
+#include <cmath>
 #include <map>
 #include <string>
 #include <vector>
@@ -69,6 +71,32 @@ Integer CeilQuotient(Integer a, Integer b) {
 std::string ObfuscateString(const std::string &src,
                             const std::string &lookup_table);

+template <typename Integer>
+inline Integer Clamp(Integer in, Integer low, Integer high) {
+  return std::max<Integer>(low, std::min<Integer>(in, high));
+}
+
+// Both helpers branch on the sign of the argument so that std::exp is
+// only evaluated on a non-positive value, which cannot overflow.
+template <typename T>
+inline T ScalarSigmoid(T in) {
+  if (in > static_cast<T>(0)) {
+    return static_cast<T>(1) / (static_cast<T>(1) + std::exp(-in));
+  } else {
+    T x = std::exp(in);
+    return x / (x + static_cast<T>(1));
+  }
+}
+
+template <typename T>
+inline T ScalarTanh(T in) {
+  if (in > static_cast<T>(0)) {
+    T inv_expa = std::exp(-in);
+    return -static_cast<T>(1) +
+        static_cast<T>(2) / (static_cast<T>(1) + inv_expa * inv_expa);
+  } else {
+    // For in <= 0, tanh(in) = (exp(2 * in) - 1) / (exp(2 * in) + 1).
+    T x = std::exp(static_cast<T>(2) * in);
+    return (x - static_cast<T>(1)) / (x + static_cast<T>(1));
+  }
+}
+
 std::string ObfuscateString(const std::string &src);
......
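A quick Python sanity check of the branch-on-sign trick used by the ScalarSigmoid and ScalarTanh helpers added to mace/utils/utils.h above: exp is only ever taken of a non-positive argument, so neither helper can overflow for large |in| (sketch only, using math rather than the C++ templates):

import math

def scalar_sigmoid(v):
    if v > 0:
        return 1.0 / (1.0 + math.exp(-v))
    x = math.exp(v)  # v <= 0, so exp cannot overflow
    return x / (x + 1.0)

def scalar_tanh(v):
    if v > 0:
        inv_e = math.exp(-v)
        return -1.0 + 2.0 / (1.0 + inv_e * inv_e)
    x = math.exp(2.0 * v)  # v <= 0, so exp cannot overflow
    return (x - 1.0) / (x + 1.0)

for v in (-1000.0, -1.0, 0.0, 1.0, 1000.0):
    assert math.isfinite(scalar_sigmoid(v)) and math.isfinite(scalar_tanh(v))
print(scalar_sigmoid(0.0), scalar_tanh(0.0))  # 0.5 0.0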
@@ -401,6 +401,7 @@ class YAMLKeyword(object):
     graph_optimize_options = 'graph_optimize_options'  # internal use for now
     cl_mem_type = 'cl_mem_type'
     backend = 'backend'
+    validation_outputs_data = 'validation_outputs_data'
     docker_image_tag = 'docker_image_tag'
     dockerfile_path = 'dockerfile_path'
     dockerfile_sha256_checksum = 'dockerfile_sha256_checksum'
......
@@ -476,6 +476,14 @@ def format_model_config(flags):
             onnx_backend = subgraph.get(
                 YAMLKeyword.backend, "tensorflow")
             subgraph[YAMLKeyword.backend] = onnx_backend
+            validation_outputs_data = subgraph.get(
+                YAMLKeyword.validation_outputs_data, [])
+            if not isinstance(validation_outputs_data, list):
+                subgraph[YAMLKeyword.validation_outputs_data] = [
+                    validation_outputs_data]
+            else:
+                subgraph[YAMLKeyword.validation_outputs_data] = \
+                    validation_outputs_data
             input_ranges = subgraph.get(
                 YAMLKeyword.input_ranges, [])
             if not isinstance(input_ranges, list):
......
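The hunk above only normalizes the new validation_outputs_data field to a list so downstream code can always iterate it; the same intent in isolation (standalone sketch, the dict stands in for the parsed YAML and the URL is a placeholder):

subgraph = {'validation_outputs_data': 'http://example.com/output.npy'}
value = subgraph.get('validation_outputs_data', [])
if not isinstance(value, list):
    subgraph['validation_outputs_data'] = [value]
print(subgraph)  # {'validation_outputs_data': ['http://example.com/output.npy']}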
@@ -660,6 +660,8 @@ class DeviceWrapper:
                         YAMLKeyword.validation_threshold][
                         validate_type],
                     backend=subgraphs[0][YAMLKeyword.backend],
+                    validation_outputs_data=subgraphs[0][
+                        YAMLKeyword.validation_outputs_data],
                     log_file=log_file,
                 )
         if flags.report and flags.round > 0:
......
...@@ -656,9 +656,12 @@ def validate_model(abi, ...@@ -656,9 +656,12 @@ def validate_model(abi,
output_file_name="model_out", output_file_name="model_out",
validation_threshold=0.9, validation_threshold=0.9,
backend="tensorflow", backend="tensorflow",
log_file="", validation_outputs_data=[],
): log_file=""):
six.print_("* Validate with %s" % platform) if not validation_outputs_data:
six.print_("* Validate with %s" % platform)
else:
six.print_("* Validate with file: %s" % validation_outputs_data)
if abi != "host": if abi != "host":
for output_name in output_nodes: for output_name in output_nodes:
formatted_name = common.formatted_file_name( formatted_name = common.formatted_file_name(
...@@ -675,6 +678,7 @@ def validate_model(abi, ...@@ -675,6 +678,7 @@ def validate_model(abi,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes), ",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types), backend, validation_threshold, ",".join(input_data_types), backend,
validation_outputs_data,
log_file) log_file)
elif platform == "onnx": elif platform == "onnx":
validate(platform, model_file_path, "", validate(platform, model_file_path, "",
...@@ -683,6 +687,7 @@ def validate_model(abi, ...@@ -683,6 +687,7 @@ def validate_model(abi,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes), ",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types), backend, validation_threshold, ",".join(input_data_types), backend,
validation_outputs_data,
log_file) log_file)
elif platform == "caffe": elif platform == "caffe":
image_name = "mace-caffe:" + docker_image_tag image_name = "mace-caffe:" + docker_image_tag
...@@ -700,6 +705,7 @@ def validate_model(abi, ...@@ -700,6 +705,7 @@ def validate_model(abi,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes), ",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types), backend, validation_threshold, ",".join(input_data_types), backend,
validation_outputs_data,
log_file) log_file)
elif caffe_env == common.CaffeEnvType.DOCKER: elif caffe_env == common.CaffeEnvType.DOCKER:
docker_image_id = sh.docker("images", "-q", image_name) docker_image_id = sh.docker("images", "-q", image_name)
...@@ -767,6 +773,8 @@ def validate_model(abi, ...@@ -767,6 +773,8 @@ def validate_model(abi,
"--validation_threshold=%f" % validation_threshold, "--validation_threshold=%f" % validation_threshold,
"--input_data_type=%s" % ",".join(input_data_types), "--input_data_type=%s" % ",".join(input_data_types),
"--backend=%s" % ",".join(backend), "--backend=%s" % ",".join(backend),
"--validation_outputs_data=%s" % ",".join(
validation_outputs_data),
"--log_file=%s" % log_file, "--log_file=%s" % log_file,
_fg=True) _fg=True)
......
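When validation is launched through the device wrapper, the list of reference files is flattened onto a single --validation_outputs_data flag with ",".join, so individual entries must not themselves contain commas. A sketch of the sending side (flag name from the diff; the URLs are hypothetical):

    refs = ["https://example.com/out0.npy",
            "https://example.com/out1.npy"]       # hypothetical
    arg = "--validation_outputs_data=%s" % ",".join(refs)
    assert arg == ("--validation_outputs_data="
                   "https://example.com/out0.npy,https://example.com/out1.npy")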
...@@ -18,6 +18,7 @@ import os ...@@ -18,6 +18,7 @@ import os
import os.path import os.path
import numpy as np import numpy as np
import re import re
import six
import common import common
...@@ -121,6 +122,32 @@ def normalize_tf_tensor_name(name): ...@@ -121,6 +122,32 @@ def normalize_tf_tensor_name(name):
return name return name
def validate_with_file(platform, device_type,
output_names, output_shapes,
mace_out_file, validation_outputs_data,
validation_threshold, log_file):
for i in range(len(output_names)):
if validation_outputs_data[i].startswith("http://") or \
validation_outputs_data[i].startswith("https://"):
validation_file_name = common.formatted_file_name(
mace_out_file, output_names[i] + '_validation')
six.moves.urllib.request.urlretrieve(validation_outputs_data[i],
validation_file_name)
else:
validation_file_name = validation_outputs_data[i]
value = load_data(validation_file_name)
out_shape = output_shapes[i]
if len(out_shape) == 4:
out_shape[1], out_shape[2], out_shape[3] = \
out_shape[3], out_shape[1], out_shape[2]
value = value.reshape(out_shape).transpose((0, 2, 3, 1))
output_file_name = common.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(platform, device_type, output_names[i], mace_out_value,
value, validation_threshold, log_file)
def validate_tf_model(platform, device_type, model_file, input_file, def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, mace_out_file, input_names, input_shapes,
output_names, validation_threshold, input_data_types, output_names, validation_threshold, input_data_types,
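validate_with_file resolves each reference blob — downloading it next to mace_out_file when the entry starts with http:// or https://, otherwise using it as a local path — and for 4-D outputs treats the file contents as NCHW while MACE writes NHWC: the configured (N, H, W, C) shape is permuted to (N, C, H, W) for the reshape, then the data is transposed back to NHWC before compare_output. A numpy sketch of just that layout step, with a hypothetical shape:

    import numpy as np

    out_shape = [1, 4, 5, 3]   # NHWC as configured (hypothetical)
    value = np.arange(np.prod(out_shape), dtype=np.float32)  # flat NCHW file

    nchw = list(out_shape)
    nchw[1], nchw[2], nchw[3] = nchw[3], nchw[1], nchw[2]    # -> (N, C, H, W)
    nhwc = value.reshape(nchw).transpose((0, 2, 3, 1))       # back to NHWC
    assert list(nhwc.shape) == out_shape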
...@@ -275,7 +302,8 @@ def validate_onnx_model(platform, device_type, model_file, input_file, ...@@ -275,7 +302,8 @@ def validate_onnx_model(platform, device_type, model_file, input_file,
def validate(platform, model_file, weight_file, input_file, mace_out_file, def validate(platform, model_file, weight_file, input_file, mace_out_file,
device_type, input_shape, output_shape, input_node, output_node, device_type, input_shape, output_shape, input_node, output_node,
validation_threshold, input_data_type, backend, log_file): validation_threshold, input_data_type, backend,
validation_outputs_data, log_file):
input_names = [name for name in input_node.split(',')] input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')] input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')] input_shapes = [[int(x) for x in shape.split(',')]
...@@ -287,8 +315,21 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file, ...@@ -287,8 +315,21 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
input_data_types = ['float32'] * len(input_names) input_data_types = ['float32'] * len(input_names)
output_names = [name for name in output_node.split(',')] output_names = [name for name in output_node.split(',')]
assert len(input_names) == len(input_shapes) assert len(input_names) == len(input_shapes)
if not isinstance(validation_outputs_data, list):
if platform == 'tensorflow': if os.path.isfile(validation_outputs_data):
validation_outputs = [validation_outputs_data]
else:
validation_outputs = []
else:
validation_outputs = validation_outputs_data
if validation_outputs:
output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')]
for shape in output_shape_strs]
validate_with_file(platform, device_type, output_names, output_shapes,
mace_out_file, validation_outputs,
validation_threshold, log_file)
elif platform == 'tensorflow':
validate_tf_model(platform, device_type, model_file, input_file, validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, mace_out_file, input_names, input_shapes,
output_names, validation_threshold, input_data_types, output_names, validation_threshold, input_data_types,
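When file-based references are present they take precedence over re-running the source framework, and only then are the output shapes decoded: like input_shape, output_shape arrives as one colon-separated string of comma-separated dims. A small sketch of that decoding (the shapes are hypothetical):

    output_shape = "1,4,5,3:1,10"   # hypothetical two-output model
    output_shapes = [[int(x) for x in s.split(',')]
                     for s in output_shape.split(':')]
    assert output_shapes == [[1, 4, 5, 3], [1, 10]]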
...@@ -358,10 +399,10 @@ def parse_args(): ...@@ -358,10 +399,10 @@ def parse_args():
default="tensorflow", default="tensorflow",
help="onnx backend framwork") help="onnx backend framwork")
parser.add_argument( parser.add_argument(
"--log_file", "--validation_outputs_data", type=str,
type=str, default="", help="validation outputs data file path.")
default="", parser.add_argument(
help="log file") "--log_file", type=str, default="", help="log file.")
return parser.parse_known_args() return parser.parse_known_args()
...@@ -381,4 +422,5 @@ if __name__ == '__main__': ...@@ -381,4 +422,5 @@ if __name__ == '__main__':
FLAGS.validation_threshold, FLAGS.validation_threshold,
FLAGS.input_data_type, FLAGS.input_data_type,
FLAGS.backend, FLAGS.backend,
FLAGS.validation_outputs_data,
FLAGS.log_file) FLAGS.log_file)