提交 c057b7ec 编写于 作者: L liutuo

add depthwise deconv

上级 487ad9dc
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_COMMON_NEON_H_
#define MACE_OPS_ARM_COMMON_NEON_H_
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
namespace mace {
namespace ops {
#ifdef MACE_ENABLE_NEON
inline float32x4_t neon_vfma_lane_0(float32x4_t a,
float32x4_t b,
float32x4_t c) {
#ifdef __aarch64__
return vfmaq_laneq_f32(a, b, c, 0);
#else
return vmlaq_lane_f32(a, b, vget_low_f32(c), 0);
#endif
}
inline float32x4_t neon_vfma_lane_1(float32x4_t a,
float32x4_t b,
float32x4_t c) {
#ifdef __aarch64__
return vfmaq_laneq_f32(a, b, c, 1);
#else
return vmlaq_lane_f32(a, b, vget_low_f32(c), 1);
#endif
}
inline float32x4_t neon_vfma_lane_2(float32x4_t a,
float32x4_t b,
float32x4_t c) {
#ifdef __aarch64__
return vfmaq_laneq_f32(a, b, c, 2);
#else
return vmlaq_lane_f32(a, b, vget_high_f32(c), 0);
#endif
}
inline float32x4_t neon_vfma_lane_3(float32x4_t a,
float32x4_t b,
float32x4_t c) {
#ifdef __aarch64__
return vfmaq_laneq_f32(a, b, c, 3);
#else
return vmlaq_lane_f32(a, b, vget_high_f32(c), 1);
#endif
}
#endif
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_COMMON_NEON_H_
......@@ -15,11 +15,8 @@
#ifndef MACE_OPS_ARM_DECONV_2D_NEON_H_
#define MACE_OPS_ARM_DECONV_2D_NEON_H_
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
#include "mace/ops/arm/common_neon.h"
namespace mace {
namespace ops {
......@@ -48,48 +45,6 @@ void Deconv2dNeonK4x4S2(const float *input,
const index_t *out_shape,
float *output);
#ifdef MACE_ENABLE_NEON
inline float32x4_t neon_vfma_lane_0(float32x4_t a,
float32x4_t b,
float32x4_t c) {
#ifdef __aarch64__
return vfmaq_laneq_f32(a, b, c, 0);
#else
return vmlaq_lane_f32(a, b, vget_low_f32(c), 0);
#endif
}
inline float32x4_t neon_vfma_lane_1(float32x4_t a,
float32x4_t b,
float32x4_t c) {
#ifdef __aarch64__
return vfmaq_laneq_f32(a, b, c, 1);
#else
return vmlaq_lane_f32(a, b, vget_low_f32(c), 1);
#endif
}
inline float32x4_t neon_vfma_lane_2(float32x4_t a,
float32x4_t b,
float32x4_t c) {
#ifdef __aarch64__
return vfmaq_laneq_f32(a, b, c, 2);
#else
return vmlaq_lane_f32(a, b, vget_high_f32(c), 0);
#endif
}
inline float32x4_t neon_vfma_lane_3(float32x4_t a,
float32x4_t b,
float32x4_t c) {
#ifdef __aarch64__
return vfmaq_laneq_f32(a, b, c, 3);
#else
return vmlaq_lane_f32(a, b, vget_high_f32(c), 1);
#endif
}
#endif
} // namespace ops
} // namespace mace
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_DEPTHWISE_DECONV2D_NEON_H_
#define MACE_OPS_ARM_DEPTHWISE_DECONV2D_NEON_H_
#include "mace/core/types.h"
#include "mace/ops/arm/common_neon.h"
namespace mace {
namespace ops {
void DepthwiseDeconv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void DepthwiseDeconv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void DepthwiseDeconv2dNeonK4x4S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void DepthwiseDeconv2dNeonK4x4S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void GroupDeconv2dNeonK3x3S1(const float *input,
const float *filter,
const int group,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void GroupDeconv2dNeonK3x3S2(const float *input,
const float *filter,
const int group,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void GroupDeconv2dNeonK4x4S1(const float *input,
const float *filter,
const int group,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void GroupDeconv2dNeonK4x4S2(const float *input,
const float *filter,
const int group,
const index_t *in_shape,
const index_t *out_shape,
float *output);
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_DEPTHWISE_DECONV2D_NEON_H_
此差异已折叠。
此差异已折叠。
......@@ -359,12 +359,12 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase {
padded_out_shape.data(),
out_data);
if (!no_pad) {
CropPadOut(out_data,
padded_out_shape.data(),
output_shape.data(),
pad_h,
pad_w,
output_data);
CropPadOut<float>(out_data,
padded_out_shape.data(),
output_shape.data(),
pad_h,
pad_w,
output_data);
}
if (bias_data != nullptr) {
......@@ -445,33 +445,6 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase {
}
}
}
void CropPadOut(const float *input,
const index_t *in_shape,
const index_t *out_shape,
const index_t pad_h,
const index_t pad_w,
float *output) {
const index_t batch = in_shape[0];
const index_t channel = in_shape[1];
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
#pragma omp parallel for collapse(3)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channel; ++j) {
for (int k = 0; k < out_height; ++k) {
const float *input_base =
input + ((i * channel + j) * in_height + (k + pad_h)) * in_width;
float *output_base =
output + ((i * channel + j) * out_height + k)* out_width;
memcpy(output_base, input_base + pad_w, out_width * sizeof(float));
}
}
}
}
};
#ifdef MACE_ENABLE_OPENCL
......
......@@ -15,6 +15,8 @@
#ifndef MACE_OPS_DECONV_2D_H_
#define MACE_OPS_DECONV_2D_H_
#include "mace/core/types.h"
namespace mace {
namespace ops {
......@@ -23,6 +25,34 @@ enum FrameworkType {
CAFFE = 1,
};
template <typename T>
void CropPadOut(const T *input,
const index_t *in_shape,
const index_t *out_shape,
const index_t pad_h,
const index_t pad_w,
T *output) {
const index_t batch = in_shape[0];
const index_t channel = in_shape[1];
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
#pragma omp parallel for collapse(3)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channel; ++j) {
for (int k = 0; k < out_height; ++k) {
const T *input_base =
input + ((i * channel + j) * in_height + (k + pad_h)) * in_width;
T *output_base =
output + ((i * channel + j) * out_height + k)* out_width;
memcpy(output_base, input_base + pad_w, out_width * sizeof(T));
}
}
}
}
} // namespace ops
} // namespace mace
......
此差异已折叠。
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D, typename T>
static void DepthwiseDeconv2d(int iters,
int batch,
int channels,
int height,
int width,
int kernel_h,
int kernel_w,
int stride,
int padding) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
}
net.AddRandomInput<D, float>("Filter",
{1, channels, kernel_h,
kernel_w});
if (D == DeviceType::GPU) {
BufferToImage<D, T>(&net, "Input", "InputImage",
ops::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage",
ops::BufferType::DW_CONV2D_FILTER);
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Output("Output")
.AddIntsArg("strides", {stride, stride})
.AddIntsArg("padding_values", {padding, padding})
.AddIntArg("group", channels)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("Input")
.Input("Filter")
.Output("Output")
.AddIntsArg("strides", {stride, stride})
.AddIntsArg("padding_values", {padding, padding})
.AddIntArg("group", channels)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.Finalize(net.NewOperatorDef());
}
net.Setup(D);
// Warm-up
for (int i = 0; i < 2; ++i) {
net.Run();
net.Sync();
}
mace::testing::StartTiming();
while (iters--) {
net.Run();
net.Sync();
}
}
// In common network, there are usually more than 1 layers, this is used to
// approximate the amortized latency. The OpenCL runtime for Mali/Adreno is
// in-order.
#define MACE_BM_DEPTHWISE_DECONV2D_MACRO( \
N, C, H, W, KH, KW, S, P, TYPE, DEVICE) \
static void \
MACE_BM_DEPTHWISE_DECONV2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##S##_##P\
##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
const int64_t macc = \
static_cast<int64_t>(iters) * N * H * W * KH * KW * C; \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
DepthwiseDeconv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, S, P); \
} \
MACE_BENCHMARK( \
MACE_BM_DEPTHWISE_DECONV2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##S##_##P\
##_##TYPE##_##DEVICE)
#define MACE_BM_DEPTHWISE_DECONV2D(N, C, H, W, KH, KW, S, P) \
MACE_BM_DEPTHWISE_DECONV2D_MACRO(N, C, H, W, KH, KW, S, P, float, CPU); \
MACE_BM_DEPTHWISE_DECONV2D_MACRO(N, C, H, W, KH, KW, S, P, float, GPU); \
MACE_BM_DEPTHWISE_DECONV2D_MACRO(N, C, H, W, KH, KW, S, P, half, GPU);
MACE_BM_DEPTHWISE_DECONV2D(1, 128, 15, 15, 1, 1, 1, 0);
MACE_BM_DEPTHWISE_DECONV2D(1, 32, 60, 60, 1, 1, 1, 0);
MACE_BM_DEPTHWISE_DECONV2D(1, 32, 60, 60, 3, 3, 1, 0);
MACE_BM_DEPTHWISE_DECONV2D(1, 128, 60, 60, 4, 4, 1, 0);
MACE_BM_DEPTHWISE_DECONV2D(1, 3, 224, 224, 4, 4, 2, 0);
MACE_BM_DEPTHWISE_DECONV2D(1, 3, 512, 512, 7, 7, 2, 0);
MACE_BM_DEPTHWISE_DECONV2D(1, 128, 16, 16, 5, 5, 1, 0);
MACE_BM_DEPTHWISE_DECONV2D(1, 64, 32, 32, 1, 1, 1, 0);
MACE_BM_DEPTHWISE_DECONV2D(1, 64, 33, 32, 3, 3, 2, 0);
MACE_BM_DEPTHWISE_DECONV2D(1, 3, 224, 224, 3, 3, 2, 0);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <vector>
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class DepthwiseDeconv2dOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void RunTestSimple(const int group,
const std::vector<index_t> &input_shape,
const std::vector<float> &input_data,
const std::vector<float> &bias_data,
const int stride,
const std::vector<int> &paddings,
const std::vector<index_t> &filter_shape,
const std::vector<float> &filter_data,
const std::vector<index_t> &expected_shape,
const std::vector<float> &expected_data) {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input_data);
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data);
net.TransformDataFormat<D, float>("Filter", HWOI, "FilterOIHW", OIHW);
const index_t out_channels = expected_shape[3];
net.AddInputFromArray<D, float>("Bias", {out_channels}, bias_data);
if (D == DeviceType::GPU) {
BufferToImage<D, float>(&net, "Input", "InputImage",
ops::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "FilterOIHW", "FilterImage",
ops::BufferType::DW_CONV2D_FILTER);
BufferToImage<D, float>(&net, "Bias", "BiasImage",
ops::BufferType::ARGUMENT);
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("group", group)
.AddIntsArg("padding_values", paddings)
.Finalize(net.NewOperatorDef());
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(&net, "OutputImage", "Output",
ops::BufferType::IN_OUT_CHANNEL);
} else {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC,
"InputNCHW", NCHW);
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntArg("group", group)
.AddIntsArg("strides", {stride, stride})
.AddIntsArg("padding_values", paddings)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
"Output", NHWC);
}
auto expected = net.CreateTensor<float>(expected_shape, expected_data);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.0001);
}
template <DeviceType D>
void TestNHWCSimple3x3_DW() {
RunTestSimple<D>(3,
{1, 3, 3, 3},
{1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1},
{0, 0, 0},
1, {0, 0},
{3, 3, 1, 3},
{1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 5, 5, 3},
{1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 1, 1, 1,
2, 2, 2, 4, 4, 4, 6, 6, 6, 4, 4, 4, 2, 2, 2,
3, 3, 3, 6, 6, 6, 9, 9, 9, 6, 6, 6, 3, 3, 3,
2, 2, 2, 4, 4, 4, 6, 6, 6, 4, 4, 4, 2, 2, 2,
1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 1, 1, 1});
}
template <DeviceType D>
void TestNHWCSimple3x3_Group() {
RunTestSimple<D>(2,
{1, 3, 3, 4},
{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4},
{0, 0, 0, 0, 0, 0},
1, {0, 0},
{3, 3, 3, 4},
{1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1},
{1, 5, 5, 6},
{3, 6, 3, 7, 14, 7,
6, 12, 6, 14, 28, 14,
9, 18, 9, 21, 42, 21,
6, 12, 6, 14, 28, 14,
3, 6, 3, 7, 14, 7,
6, 12, 6, 14, 28, 14,
12, 24, 12, 28, 56, 28,
18, 36, 18, 42, 84, 42,
12, 24, 12, 28, 56, 28,
6, 12, 6, 14, 28, 14,
9, 18, 9, 21, 42, 21,
18, 36, 18, 42, 84, 42,
27, 54, 27, 63, 126, 63,
18, 36, 18, 42, 84, 42,
9, 18, 9, 21, 42, 21,
6, 12, 6, 14, 28, 14,
12, 24, 12, 28, 56, 28,
18, 36, 18, 42, 84, 42,
12, 24, 12, 28, 56, 28,
6, 12, 6, 14, 28, 14,
3, 6, 3, 7, 14, 7,
6, 12, 6, 14, 28, 14,
9, 18, 9, 21, 42, 21,
6, 12, 6, 14, 28, 14,
3, 6, 3, 7, 14, 7});
}
} // namespace
TEST_F(DepthwiseDeconv2dOpTest, CPUSimple3X3Depthwise) {
TestNHWCSimple3x3_DW<DeviceType::CPU>();
}
TEST_F(DepthwiseDeconv2dOpTest, CPUSimple3X3Group) {
TestNHWCSimple3x3_Group<DeviceType::CPU>();
}
TEST_F(DepthwiseDeconv2dOpTest, GPUSimple3X3Depthwise) {
TestNHWCSimple3x3_DW<DeviceType::GPU>();
}
namespace {
template <typename T>
void RandomTest(index_t batch,
index_t channel,
index_t height,
index_t width,
index_t kernel,
int stride,
int padding) {
testing::internal::LogToStderr();
// Construct graph
OpsTestNet net;
int multiplier = 1;
// Add input data
std::vector<float> input_data(batch * height * width * channel);
GenerateRandomRealTypeData({batch, height, width, channel}, &input_data);
net.AddInputFromArray<DeviceType::GPU, float>("Input",
{batch,
height,
width,
channel},
input_data);
std::vector<float> filter_data(kernel * kernel * channel * multiplier);
GenerateRandomRealTypeData({multiplier, channel, kernel, kernel},
&filter_data);
net.AddInputFromArray<DeviceType::GPU, float>(
"Filter", {multiplier, channel, kernel, kernel}, filter_data);
std::vector<float> bias_data(channel * multiplier);
GenerateRandomRealTypeData({channel * multiplier}, &bias_data);
net.AddInputFromArray<DeviceType::GPU, float>("Bias",
{channel * multiplier},
bias_data);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("InputNCHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride, stride})
.AddIntsArg("padding_values", {padding, padding})
.AddIntArg("group", channel)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
"Output", NHWC);
// Check
auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output"));
BufferToImage<DeviceType::GPU, T>(&net, "Input", "InputImage",
ops::BufferType::IN_OUT_CHANNEL);
BufferToImage<DeviceType::GPU, T>(&net, "Filter", "FilterImage",
ops::BufferType::DW_CONV2D_FILTER);
BufferToImage<DeviceType::GPU, T>(&net, "Bias", "BiasImage",
ops::BufferType::ARGUMENT);
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride, stride})
.AddIntsArg("padding_values", {padding, padding})
.AddIntArg("group", channel)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
net.RunOp(DeviceType::GPU);
// Transfer output
ImageToBuffer<DeviceType::GPU, float>(&net, "OutputImage", "OPENCLOutput",
ops::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DT_FLOAT) {
ExpectTensorNear<float>(*expected, *net.GetOutput("OPENCLOutput"), 1e-5);
} else {
ExpectTensorNear<float>(*expected, *net.GetOutput("OPENCLOutput"), 1e-2);
}
}
TEST_F(DepthwiseDeconv2dOpTest, RandomTestFloat) {
RandomTest<float>(1, 32, 256, 256, 5, 1, 2);
RandomTest<float>(1, 3, 256, 256, 5, 1, 1);
RandomTest<float>(1, 3, 256, 256, 5, 2, 2);
RandomTest<float>(1, 3, 256, 256, 5, 1, 3);
RandomTest<float>(1, 3, 256, 256, 5, 2, 4);
RandomTest<float>(1, 4, 256, 256, 5, 1, 1);
RandomTest<float>(1, 4, 256, 256, 5, 2, 2);
RandomTest<float>(1, 4, 256, 256, 5, 1, 3);
RandomTest<float>(1, 4, 256, 256, 5, 2, 4);
}
//
TEST_F(DepthwiseDeconv2dOpTest, RandomTestHalf) {
RandomTest<half>(1, 32, 256, 256, 5, 1, 2);
RandomTest<half>(1, 3, 256, 256, 5, 1, 1);
RandomTest<half>(1, 3, 256, 256, 5, 2, 2);
RandomTest<half>(1, 3, 256, 256, 5, 1, 3);
RandomTest<half>(1, 3, 256, 256, 5, 2, 4);
RandomTest<half>(1, 4, 256, 256, 5, 1, 1);
RandomTest<half>(1, 4, 256, 256, 5, 2, 2);
RandomTest<half>(1, 4, 256, 256, 5, 1, 3);
RandomTest<half>(1, 4, 256, 256, 5, 2, 4);
}
} // namespace
} // namespace test
} // namespace ops
} // namespace mace
......@@ -163,4 +163,5 @@ __kernel void deconv_2d(OUT_OF_RANGE_PARAMS
out_pos.x += stride_w;
WRITE_IMAGET(output, out_pos, out4);
}
}
\ No newline at end of file
}
#include <common.h>
__kernel void depthwise_deconv2d(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__read_only image2d_t weights,
#ifdef BIAS
__read_only image2d_t bias,
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const int in_height,
__private const int in_width,
__private const int out_height,
__private const int out_width,
__private const int out_channel,
__private const int stride_h,
__private const int stride_w,
__private const float stride_h_r,
__private const float stride_w_r,
__private const int align_h,
__private const int align_w,
__private const int padding_h,
__private const int padding_w,
__private const int kernel_h,
__private const int kernel_w,
__private const int kernel_size,
__private const int out_channel_blocks)
{
const int c = get_global_id(0);
const int w_id = get_global_id(1);
const int hb = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (c >= global_size_dim0 || w_id >= global_size_dim1
|| hb >= global_size_dim2) {
return;
}
#endif
#ifdef BIAS
DATA_TYPE4 out0 =
READ_IMAGET(bias, SAMPLER, (int2)(c, 0));
DATA_TYPE4 out1 = out0;
DATA_TYPE4 out2 = out0;
DATA_TYPE4 out3 = out0;
DATA_TYPE4 out4 = out0;
#else
DATA_TYPE4 out0 = 0;
DATA_TYPE4 out1 = 0;
DATA_TYPE4 out2 = 0;
DATA_TYPE4 out3 = 0;
DATA_TYPE4 out4 = 0;
#endif
const int n_stride = mad(w_id, stride_w_r, 0);
const int mod_stride = w_id - mul24(n_stride, stride_w);
const int w = mad24(mul24(n_stride, 5), stride_w, mod_stride);
const int b = hb / out_height;
const int h = hb - mul24(b, out_height);
if (w < out_width) {
int start_x = floor((float) (w + align_w) * stride_w_r);
int start_y = (h + align_h) * stride_h_r;
start_y = max(0, start_y);
int f_start_x = mad24(start_x, stride_w, padding_w) - w;
int f_start_y = mad24(start_y, stride_h, padding_h) - h;
f_start_x = kernel_w - 1 - f_start_x;
f_start_y = kernel_h - 1 - f_start_y;
int2 in_pos;
int f_pos;
DATA_TYPE4 in0, in1, in2, in3, in4;
DATA_TYPE4 weight;
int idx_w0, idx_w1, idx_w2, idx_w3, idx_w4;
int index_x, index_y;
for (int f_y = f_start_y, idx_h = start_y ; f_y >= 0; f_y -= stride_h, ++idx_h) {
index_y = mad24(b, in_height, idx_h);
in_pos.y = select(index_y, -1, idx_h < 0 || idx_h >= in_height);
for (int f_x = f_start_x, idx_w = start_x; f_x >= 0; f_x -= stride_w, ++idx_w) {
idx_w0 = idx_w;
idx_w1 = idx_w + 1;
idx_w2 = idx_w + 2;
idx_w3 = idx_w + 3;
idx_w4 = idx_w + 4;
#define READ_INPUT(i) \
index_x = mad24(c, in_width, idx_w##i); \
in_pos.x = \
select(index_x, -1, idx_w##i < 0 || idx_w##i >= in_width); \
in##i = READ_IMAGET(input, SAMPLER, in_pos);
READ_INPUT(0);
READ_INPUT(1);
READ_INPUT(2);
READ_INPUT(3);
READ_INPUT(4);
#undef READ_INPUT
f_pos = mad24(f_y, kernel_w, f_x);
weight = READ_IMAGET(weights, SAMPLER, (int2)(f_pos, c));
out0 = mad(in0, weight, out0);
out1 = mad(in1, weight, out1);
out2 = mad(in2, weight, out2);
out3 = mad(in3, weight, out3);
out4 = mad(in4, weight, out4);
}
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit);
out1 = do_activation(out1, relux_max_limit);
out2 = do_activation(out2, relux_max_limit);
out3 = do_activation(out3, relux_max_limit);
out4 = do_activation(out4, relux_max_limit);
#endif
int2 out_pos;
out_pos.y = hb;
int ow = w;
if (ow >= out_width) return;
out_pos.x = mad24(c, out_width, ow);
WRITE_IMAGET(output, out_pos, out0);
ow += stride_w;
if (ow >= out_width) return;
out_pos.x += stride_w;
WRITE_IMAGET(output, out_pos, out1);
ow += stride_w;
if (ow >= out_width) return;
out_pos.x += stride_w;
WRITE_IMAGET(output, out_pos, out2);
ow += stride_w;
if (ow >= out_width) return;
out_pos.x += stride_w;
WRITE_IMAGET(output, out_pos, out3);
ow += stride_w;
if (ow >= out_width) return;
out_pos.x += stride_w;
WRITE_IMAGET(output, out_pos, out4);
}
}
\ No newline at end of file
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_DEPTHWISE_DECONV2D_H_
#define MACE_OPS_OPENCL_DEPTHWISE_DECONV2D_H_
#include <vector>
#include "mace/ops/activation.h"
namespace mace {
class OpContext;
class Tensor;
namespace ops {
class OpenCLDepthwiseDeconv2dKernel {
public:
virtual MaceStatus Compute(
OpContext *context,
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *strides,
const int *padding_data,
const int group,
const ActivationType activation,
const float relux_max_limit,
const std::vector <index_t> &output_shape,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLDepthwiseDeconv2dKernel);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_DEPTHWISE_DECONV2D_H_
......@@ -79,10 +79,10 @@ MaceStatus Deconv2dKernel<T>::Compute(
const int stride_h = strides[0];
const int stride_w = strides[1];
MACE_CHECK(stride_w > 0 && stride_h > 0, "strides should be > 0.");
#define MACE_WIDTH_BLK 5
const int width_tile = 5;
const index_t n_strides = (width + stride_w - 1) / stride_w;
const index_t width_blocks =
((n_strides + MACE_WIDTH_BLK - 1) / MACE_WIDTH_BLK) * stride_w;
((n_strides + width_tile - 1) / width_tile) * stride_w;
const float stride_h_r = 1.f / static_cast<float>(stride_h);
const float stride_w_r = 1.f / static_cast<float>(stride_w);
const int padding_h = (padding_data[0] + 1) >> 1;
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_IMAGE_DEPTHWISE_DECONV2D_H_
#define MACE_OPS_OPENCL_IMAGE_DEPTHWISE_DECONV2D_H_
#include "mace/ops/opencl/depthwise_deconv2d.h"
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mace/core/op_context.h"
#include "mace/core/tensor.h"
#include "mace/ops/opencl/helper.h"
namespace mace {
namespace ops {
namespace opencl {
namespace image {
template <typename T>
class DepthwiseDeconv2dKernel : public OpenCLDepthwiseDeconv2dKernel {
public:
MaceStatus Compute(
OpContext *context,
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *strides,
const int *padding_data,
const int group,
const ActivationType activation,
const float relux_max_limit,
const std::vector<index_t> &output_shape,
Tensor *output) override;
private:
cl::Kernel kernel_;
uint32_t kwg_size_;
std::vector<index_t> input_shape_;
};
template <typename T>
MaceStatus DepthwiseDeconv2dKernel<T>::Compute(
OpContext *context,
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int *strides,
const int *padding_data,
const int group,
const ActivationType activation,
const float relux_max_limit,
const std::vector<index_t> &output_shape,
Tensor *output) {
const index_t batch = output_shape[0];
const index_t height = output_shape[1];
const index_t width = output_shape[2];
const index_t channels = output_shape[3];
const index_t input_channels = input->dim(3);
const index_t multiplier = filter->dim(0);
MACE_CHECK(group == channels && group == input_channels && multiplier == 1,
"opencl image deconv only supports depthwise type group.");
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
const DataType dt = DataTypeToEnum<T>::value;
const index_t channel_blocks = RoundUpDiv4(channels);
const int stride_h = strides[0];
const int stride_w = strides[1];
MACE_CHECK(stride_w > 0 && stride_h > 0, "strides should be > 0.");
const int width_tile = 5;
const index_t n_strides = (width + stride_w - 1) / stride_w;
const index_t width_blocks =
((n_strides + width_tile - 1) / width_tile) * stride_w;
const float stride_h_r = 1.f / static_cast<float>(stride_h);
const float stride_w_r = 1.f / static_cast<float>(stride_w);
const int padding_h = (padding_data[0] + 1) >> 1;
const int padding_w = (padding_data[1] + 1) >> 1;
const int align_h = stride_h - 1 - padding_h;
const int align_w = stride_w - 1 - padding_w;
const int kernel_size = filter->dim(2) * filter->dim(3);
auto runtime = context->device()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("depthwise_deconv2d");
built_options.emplace("-Ddepthwise_deconv2d=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
switch (activation) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
MACE_RETURN_IF_ERROR(runtime->BuildKernel("depthwise_deconv2d", kernel_name,
built_options, &kernel_));
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height * batch)};
MACE_OUT_OF_RANGE_INIT(kernel_);
if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
MACE_SET_3D_GWS_ARGS(kernel_, gws);
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, *(filter->opencl_image()));
if (bias != nullptr) {
kernel_.setArg(idx++, *(bias->opencl_image()));
}
kernel_.setArg(idx++, *(output->opencl_image()));
kernel_.setArg(idx++, relux_max_limit);
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(1)));
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(2)));
kernel_.setArg(idx++, static_cast<int32_t>(height));
kernel_.setArg(idx++, static_cast<int32_t>(width));
kernel_.setArg(idx++, static_cast<int32_t>(channels));
kernel_.setArg(idx++, static_cast<int32_t>(stride_h));
kernel_.setArg(idx++, static_cast<int32_t>(stride_w));
kernel_.setArg(idx++, stride_h_r);
kernel_.setArg(idx++, stride_w_r);
kernel_.setArg(idx++, static_cast<int32_t>(align_h));
kernel_.setArg(idx++, static_cast<int32_t>(align_w));
kernel_.setArg(idx++, static_cast<int32_t>(padding_h));
kernel_.setArg(idx++, static_cast<int32_t>(padding_w));
kernel_.setArg(idx++, static_cast<int32_t>(filter->dim(2)));
kernel_.setArg(idx++, static_cast<int32_t>(filter->dim(3)));
kernel_.setArg(idx++, static_cast<int32_t>(kernel_size));
kernel_.setArg(idx++, static_cast<int32_t>(channel_blocks));
input_shape_ = input->shape();
}
const std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_);
std::string tuning_key =
Concat("depthwise_deconv2d_kernel_",
activation,
output->dim(0),
output->dim(1),
output->dim(2),
output->dim(3));
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_IMAGE_DEPTHWISE_DECONV2D_H_
......@@ -32,6 +32,7 @@ extern void RegisterCrop(OpRegistryBase *op_registry);
extern void RegisterDeconv2D(OpRegistryBase *op_registry);
extern void RegisterDepthToSpace(OpRegistryBase *op_registry);
extern void RegisterDepthwiseConv2d(OpRegistryBase *op_registry);
extern void RegisterDepthwiseDeconv2d(OpRegistryBase *op_registry);
extern void RegisterDequantize(OpRegistryBase *op_registry);
extern void RegisterEltwise(OpRegistryBase *op_registry);
extern void RegisterExpandDims(OpRegistryBase *op_registry);
......@@ -89,6 +90,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterDeconv2D(this);
ops::RegisterDepthToSpace(this);
ops::RegisterDepthwiseConv2d(this);
ops::RegisterDepthwiseDeconv2d(this);
ops::RegisterDequantize(this);
ops::RegisterEltwise(this);
ops::RegisterExpandDims(this);
......
......@@ -91,6 +91,7 @@ MaceSupportedOps = [
'Deconv2D',
'DepthToSpace',
'DepthwiseConv2d',
'DepthwiseDeconv2d',
'Dequantize',
'Eltwise',
'ExpandDims',
......@@ -183,6 +184,7 @@ class MaceKeyword(object):
mace_scalar_input_index_str = 'scalar_input_index'
mace_opencl_mem_type = "opencl_mem_type"
mace_framework_type_str = "framework_type"
mace_group_str = "group"
class TransformerRule(Enum):
......
......@@ -411,17 +411,14 @@ class CaffeConverter(base_converter.ConverterInterface):
def convert_deconv2d(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.convolution_param
is_depthwise = False
if param.HasField(caffe_group_str) and param.group > 1:
filter_data = caffe_op.blobs[0]
mace_check(param.group == filter_data.shape[0] and
filter_data.shape[1] == 1,
"Mace does not support group deconvolution yet")
is_depthwise = True
mace_check(is_depthwise is False,
"Mace do not support depthwise deconvolution yet")
op.type = MaceOp.Deconv2D.name
if param.HasField(caffe_group_str) and param.group > 1:
group_arg = op.arg.add()
group_arg.name = MaceKeyword.mace_group_str
group_arg.i = param.group
op.type = MaceOp.DepthwiseDeconv2d.name
else:
op.type = MaceOp.Deconv2D.name
self.add_stride_pad_kernel_arg(param, op)
# dilation is specific for convolution in caffe
......
......@@ -36,6 +36,7 @@ class ShapeInference(object):
MaceOp.Conv2D.name: self.infer_shape_conv_pool_shape,
MaceOp.Deconv2D.name: self.infer_shape_deconv,
MaceOp.DepthwiseConv2d.name: self.infer_shape_conv_pool_shape,
MaceOp.DepthwiseDeconv2d.name: self.infer_shape_deconv,
MaceOp.Eltwise.name: self.infer_shape_general,
MaceOp.BatchNorm.name: self.infer_shape_general,
MaceOp.AddN.name: self.infer_shape_general,
......@@ -159,11 +160,15 @@ class ShapeInference(object):
dilations = [1, 1]
round_func = math.floor
group_arg = ConverterUtil.get_arg(op,
MaceKeyword.mace_group_str)
output_shape[0] = input_shape[0]
if ConverterUtil.data_format(op) == DataFormat.NCHW \
and ConverterUtil.filter_format(self._net) == FilterFormat.OIHW: # noqa
# filter format: IOHW
output_shape[1] = filter_shape[1]
if group_arg is not None and group_arg.i > 1:
output_shape[1] = group_arg.i * filter_shape[1]
output_shape[2] = int(
round_func((input_shape[2] - 1) * strides[0] +
(filter_shape[2] - 1) * (dilations[0] - 1) +
......
......@@ -583,7 +583,7 @@ class Transformer(base_converter.ConverterInterface):
def fold_deconv_and_bn(self):
net = self._model
for op in net.op:
if (op.type == MaceOp.Deconv2D.name) \
if (op.type in [MaceOp.Deconv2D.name, MaceOp.DepthwiseDeconv2d]) \
and self.consumer_count(op.output[0]) == 1:
consumer_op = self._consumers[op.output[0]][0]
if consumer_op.type == MaceOp.BatchNorm.name:
......@@ -1365,7 +1365,8 @@ class Transformer(base_converter.ConverterInterface):
self.set_filter_format(FilterFormat.OIHW)
# deconv's filter's output channel and input channel is reversed
for op in net.op:
if op.type == MaceOp.Deconv2D.name \
if op.type in [MaceOp.Deconv2D.name,
MaceOp.DepthwiseDeconv2d] \
and op.input[1] not in transposed_deconv_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
......@@ -1427,11 +1428,17 @@ class Transformer(base_converter.ConverterInterface):
self.buffer_transform(op, 1, OpenCLBufferType.CONV2D_FILTER)
if len(op.input) >= 3:
self.buffer_transform(op, 2, OpenCLBufferType.ARGUMENT)
elif op.type == MaceOp.Deconv2D.name:
self.buffer_transform(op, 1, OpenCLBufferType.CONV2D_FILTER)
elif op.type == MaceOp.Deconv2D.name\
or op.type == MaceOp.DepthwiseDeconv2d.name:
if op.type == MaceOp.Deconv2D.name:
self.buffer_transform(op, 1,
OpenCLBufferType.CONV2D_FILTER)
elif op.type == MaceOp.DepthwiseDeconv2d.name:
self.buffer_transform(op, 1,
OpenCLBufferType.DW_CONV2D_FILTER)
if ConverterUtil.get_arg(
op,
MaceKeyword.mace_framework_type_str).i ==\
MaceKeyword.mace_framework_type_str).i == \
FrameworkType.CAFFE.value:
if len(op.input) >= 3:
self.buffer_transform(op, 2, OpenCLBufferType.ARGUMENT)
......@@ -1456,8 +1463,10 @@ class Transformer(base_converter.ConverterInterface):
if len(op.input) >= 4:
self.buffer_transform(op, 3, OpenCLBufferType.ARGUMENT)
elif op.type == MaceOp.MatMul.name and \
ConverterUtil.get_arg(op,
MaceKeyword.mace_winograd_filter_transformed) is not None: # noqa
ConverterUtil.get_arg(
op,
MaceKeyword.mace_winograd_filter_transformed
) is not None: # noqa
self.buffer_transform(op, 0, OpenCLBufferType.WINOGRAD_FILTER)
elif op.type == MaceOp.WinogradInverseTransform.name \
and len(op.input) >= 3:
......@@ -1467,8 +1476,10 @@ class Transformer(base_converter.ConverterInterface):
if len(op.input) >= 3:
self.buffer_transform(op, 2, OpenCLBufferType.ARGUMENT)
elif op.type == MaceOp.Activation.name:
if ConverterUtil.get_arg(op,
MaceKeyword.mace_activation_type_str).s == ActivationType.PRELU.name: # noqa
if ConverterUtil.get_arg(
op,
MaceKeyword.mace_activation_type_str
).s == ActivationType.PRELU.name: # noqa
self.buffer_transform(op, 1, OpenCLBufferType.ARGUMENT)
elif op.type == MaceOp.LSTMCell.name:
if op.input[1] in self._consts:
......@@ -1793,24 +1804,24 @@ class Transformer(base_converter.ConverterInterface):
check_conv = False
check_deconv = False
if ops is not None and len(ops) == 1:
check_conv =\
ops[0].type in [MaceOp.Conv2D.name,
MaceOp.DepthwiseConv2d.name,
MaceOp.FullyConnected.name]\
and len(ops[0].input) >= 3\
and ops[0].input[2] == tensor.name
if len(ops[0].input) >= 3:
check_conv =\
ops[0].type in [MaceOp.Conv2D.name,
MaceOp.DepthwiseConv2d.name,
MaceOp.FullyConnected.name]\
and ops[0].input[2] == tensor.name
# in tensorflow deconv's bias is the forth input
if ops[0].type == MaceOp.Deconv2D.name:
if ops[0].type in [MaceOp.Deconv2D.name,
MaceOp.DepthwiseDeconv2d]:
from_caffe = ConverterUtil.get_arg(
ops[0],
MaceKeyword.mace_framework_type_str).i ==\
FrameworkType.CAFFE.value
if from_caffe:
check_deconv = len(ops[0].input) >= 3\
and ops[0].input[2] == tensor.name
if from_caffe and len(ops[0].input) >= 3:
check_deconv = ops[0].input[2] == tensor.name
else:
check_deconv = len(ops[0].input) >= 4\
and ops[0].input[3] == tensor.name
if len(ops[0].input) >= 4:
check_deconv = ops[0].input[3] == tensor.name
if check_conv or check_deconv:
if self._option.device == DeviceType.CPU.value:
conv_op = ops[0]
......
......@@ -37,6 +37,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/conv_2d_buffer.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/crop.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/deconv_2d.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/depthwise_deconv2d.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/depth_to_space.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/depthwise_conv2d.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/depthwise_conv2d_buffer.cl"))
......
......@@ -18,8 +18,6 @@ import os
import os.path
import numpy as np
import re
from scipy import spatial
from scipy import stats
import common
......@@ -60,14 +58,22 @@ def calculate_sqnr(expected, actual):
return signal_power_sum / (noise_power_sum + 1e-15)
def calculate_similarity(u, v, data_type=np.float64):
if u.dtype is not data_type:
u = u.astype(data_type)
if v.dtype is not data_type:
v = v.astype(data_type)
return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
def compare_output(platform, device_type, output_name, mace_out_value,
out_value, validation_threshold):
if mace_out_value.size != 0:
out_value = out_value.reshape(-1)
mace_out_value = mace_out_value.reshape(-1)
assert len(out_value) == len(mace_out_value)
similarity = (1 - spatial.distance.cosine(out_value, mace_out_value))
sqnr = calculate_sqnr(out_value, mace_out_value)
similarity = calculate_similarity(out_value, mace_out_value)
common.MaceLogger.summary(
output_name + ' MACE VS ' + platform.upper()
+ ' similarity: ' + str(similarity) + ' , sqnr: ' + str(sqnr))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册