提交 13bca3d1 编写于 作者: U Unknown 提交者: liutuo

optimize code for cwise

上级 d46c3965
...@@ -73,6 +73,7 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry); ...@@ -73,6 +73,7 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry);
extern void Register_ChannelShuffle(OperatorRegistry *op_registry); extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry); extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry); extern void Register_Conv2D(OperatorRegistry *op_registry);
extern void Register_CWise(OperatorRegistry *op_registry);
extern void Register_DepthToSpace(OperatorRegistry *op_registry); extern void Register_DepthToSpace(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_Eltwise(OperatorRegistry *op_registry); extern void Register_Eltwise(OperatorRegistry *op_registry);
...@@ -82,14 +83,12 @@ extern void Register_FusedConv2D(OperatorRegistry *op_registry); ...@@ -82,14 +83,12 @@ extern void Register_FusedConv2D(OperatorRegistry *op_registry);
extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry); extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry);
extern void Register_ImageToBuffer(OperatorRegistry *op_registry); extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry); extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Neg(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry); extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry); extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_PSROIAlign(OperatorRegistry *op_registry); extern void Register_PSROIAlign(OperatorRegistry *op_registry);
extern void Register_ReOrganize(OperatorRegistry *op_registry); extern void Register_ReOrganize(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry); extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_ScalarMath(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry); extern void Register_Slice(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry); extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
...@@ -111,6 +110,7 @@ OperatorRegistry::OperatorRegistry() { ...@@ -111,6 +110,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_ChannelShuffle(this); ops::Register_ChannelShuffle(this);
ops::Register_Concat(this); ops::Register_Concat(this);
ops::Register_Conv2D(this); ops::Register_Conv2D(this);
ops::Register_CWise(this);
ops::Register_DepthToSpace(this); ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this); ops::Register_DepthwiseConv2d(this);
ops::Register_Eltwise(this); ops::Register_Eltwise(this);
...@@ -120,14 +120,12 @@ OperatorRegistry::OperatorRegistry() { ...@@ -120,14 +120,12 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_GlobalAvgPooling(this); ops::Register_GlobalAvgPooling(this);
ops::Register_ImageToBuffer(this); ops::Register_ImageToBuffer(this);
ops::Register_MatMul(this); ops::Register_MatMul(this);
ops::Register_Neg(this);
ops::Register_Pooling(this); ops::Register_Pooling(this);
ops::Register_Proposal(this); ops::Register_Proposal(this);
ops::Register_PSROIAlign(this); ops::Register_PSROIAlign(this);
ops::Register_ReOrganize(this); ops::Register_ReOrganize(this);
ops::Register_Reshape(this); ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this); ops::Register_ResizeBilinear(this);
ops::Register_ScalarMath(this);
ops::Register_Slice(this); ops::Register_Slice(this);
ops::Register_Softmax(this); ops::Register_Softmax(this);
ops::Register_SpaceToBatchND(this); ops::Register_SpaceToBatchND(this);
......
// //
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#ifndef MACE_KERNELS_SCALAR_MATH_H_ #ifndef MACE_KERNELS_CWISE_H_
#define MACE_KERNELS_SCALAR_MATH_H_ #define MACE_KERNELS_CWISE_H_
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
...@@ -14,27 +14,29 @@ ...@@ -14,27 +14,29 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
enum ScalarMathType { enum CWiseType {
MUL = 0, MUL = 0,
ADD = 1, ADD = 1,
MAX = 2, MAX = 2,
MIN = 3, MIN = 3,
SUB = 4, SUB = 4,
DIV = 5, DIV = 5,
NEG = 6,
ABS = 7,
}; };
struct ScalarMathFunctorBase { struct CWiseFunctorBase {
ScalarMathFunctorBase(const ScalarMathType type, const float coeff) CWiseFunctorBase(const CWiseType type, const float coeff)
: type_(type), coeff_(coeff) {} : type_(type), coeff_(coeff) {}
ScalarMathType type_; CWiseType type_;
float coeff_; float coeff_;
}; };
template <DeviceType D, typename T> template <DeviceType D, typename T>
struct ScalarMathFunctor : ScalarMathFunctorBase { struct CWiseFunctor : CWiseFunctorBase {
ScalarMathFunctor(const ScalarMathType type, const float coeff) CWiseFunctor(const CWiseType type, const float coeff)
: ScalarMathFunctorBase(type, coeff) {} : CWiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input, void operator()(const Tensor *input,
Tensor *output, Tensor *output,
...@@ -59,6 +61,18 @@ struct ScalarMathFunctor : ScalarMathFunctorBase { ...@@ -59,6 +61,18 @@ struct ScalarMathFunctor : ScalarMathFunctorBase {
output_ptr[i] = coeff_ + input_ptr[i]; output_ptr[i] = coeff_ + input_ptr[i];
} }
break; break;
case MAX:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max<T>(input_ptr[i], coeff_);
}
break;
case MIN:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::min<T>(input_ptr[i], coeff_);
}
break;
case SUB: case SUB:
#pragma omp parallel for #pragma omp parallel for
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
...@@ -71,16 +85,29 @@ struct ScalarMathFunctor : ScalarMathFunctorBase { ...@@ -71,16 +85,29 @@ struct ScalarMathFunctor : ScalarMathFunctorBase {
output_ptr[i] = input_ptr[i] / coeff_; output_ptr[i] = input_ptr[i] / coeff_;
} }
break; break;
case NEG:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = 0 - input_ptr[i];
}
break;
case ABS:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
T val = input_ptr[i];
output_ptr[i] = (val > 0)? val : 0 - val;
}
break;
default: default:
LOG(FATAL) << "ScalarMath op not support type " << type_; LOG(FATAL) << "CWise op not support type " << type_;
} }
} }
}; };
template <typename T> template <typename T>
struct ScalarMathFunctor<DeviceType::OPENCL, T> : ScalarMathFunctorBase { struct CWiseFunctor<DeviceType::OPENCL, T> : CWiseFunctorBase {
ScalarMathFunctor(const ScalarMathType type, const float coeff) CWiseFunctor(const CWiseType type, const float coeff)
: ScalarMathFunctorBase(type, coeff) {} : CWiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input, void operator()(const Tensor *input,
Tensor *output, Tensor *output,
...@@ -93,4 +120,4 @@ struct ScalarMathFunctor<DeviceType::OPENCL, T> : ScalarMathFunctorBase { ...@@ -93,4 +120,4 @@ struct ScalarMathFunctor<DeviceType::OPENCL, T> : ScalarMathFunctorBase {
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#endif // MACE_KERNELS_SCALAR_MATH_H_ #endif // MACE_KERNELS_CWISE_H_
...@@ -41,7 +41,7 @@ struct EltwiseFunctor : EltwiseFunctorBase { ...@@ -41,7 +41,7 @@ struct EltwiseFunctor : EltwiseFunctorBase {
StatsFuture *future) { StatsFuture *future) {
Tensor::MappingGuard input0_guard(input0); Tensor::MappingGuard input0_guard(input0);
Tensor::MappingGuard input1_guard(input1); Tensor::MappingGuard input1_guard(input1);
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output);
const T *input0_ptr = input0->data<T>(); const T *input0_ptr = input0->data<T>();
const T *input1_ptr = input1->data<T>(); const T *input1_ptr = input1->data<T>();
...@@ -56,12 +56,12 @@ struct EltwiseFunctor : EltwiseFunctorBase { ...@@ -56,12 +56,12 @@ struct EltwiseFunctor : EltwiseFunctorBase {
} }
break; break;
case SUM: case SUM:
if (coeff_.empty()) { if (coeff_.empty()) {
#pragma omp parallel for #pragma omp parallel for
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
output_ptr[i] = input0_ptr[i] + input1_ptr[i]; output_ptr[i] = input0_ptr[i] + input1_ptr[i];
} }
} else { } else {
#pragma omp parallel for #pragma omp parallel for
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
output_ptr[i] = output_ptr[i] =
...@@ -69,13 +69,13 @@ struct EltwiseFunctor : EltwiseFunctorBase { ...@@ -69,13 +69,13 @@ struct EltwiseFunctor : EltwiseFunctorBase {
} }
} }
break; break;
case MAX: case MAX:
#pragma omp parallel for #pragma omp parallel for
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max<T>(input0_ptr[i], input1_ptr[i]); output_ptr[i] = std::max<T>(input0_ptr[i], input1_ptr[i]);
} }
break; break;
case MIN: case MIN:
#pragma omp parallel for #pragma omp parallel for
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::min<T>(input0_ptr[i], input1_ptr[i]); output_ptr[i] = std::min<T>(input0_ptr[i], input1_ptr[i]);
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_NEGATIVE_H_
#define MACE_KERNELS_NEGATIVE_H_
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct NegFunctor {
void operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
#pragma omp parallel for collapse(4)
for (index_t n = 0; n < batch; ++n) {
for (index_t h = 0; h < height; ++h) {
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < channels; ++c) {
index_t pos = (((n * height) + h) * width + w) * channels + c;
output_ptr[pos] = 0 - input_ptr[pos];
}
}
}
}
}
};
/*
template <>
void NegFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
*/
template <typename T>
struct NegFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
std::vector<index_t> input_shape_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_NEGATIVE_H_
#include <common.h> #include <common.h>
__kernel void scalar_math(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __kernel void cwise(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__private const float scalar, __private const float value,
__write_only image2d_t output) { __write_only image2d_t output) {
const int w = get_global_id(0); const int w = get_global_id(0);
const int hb = get_global_id(1); const int hb = get_global_id(1);
DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb)); DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb));
DATA_TYPE4 in1; DATA_TYPE4 in1 = (DATA_TYPE4){value, value, value, value};
in1.x = scalar;
in1.y = scalar;
in1.z = scalar;
in1.w = scalar;
DATA_TYPE4 out; DATA_TYPE4 out;
#if SCALAR_MATH_TYPE == 1
#if CWISE_TYPE == 0
out = in0 * in1;
#elif CWISE_TYPE == 1
out = in0 + in1; out = in0 + in1;
#elif SCALAR_MATH_TYPE == 4 #elif CWISE_TYPE == 2
out.x = fmax(in0.x, value);
out.y = fmax(in0.y, value);
out.z = fmax(in0.z, value);
out.z = fmax(in0.w, value);
#elif CWISE_TYPE == 3
out.x = fmin(in0.x, value);
out.y = fmin(in0.y, value);
out.z = fmin(in0.z, value);
out.z = fmin(in0.w, value);
#elif CWISE_TYPE == 4
out = in0 - in1; out = in0 - in1;
#elif SCALAR_MATH_TYPE == 0 #elif CWISE_TYPE == 5
out = in0 * in1;
#elif SCALAR_MATH_TYPE == 5
out = in0 / in1; out = in0 / in1;
#elif CWISE_TYPE == 6
in1 = (DATA_TYPE4)(0, 0, 0, 0);
out = in1 - in0;
#elif CWISE_TYPE == 7
out.x = fabs(in0.x);
out.y = fabs(in0.y);
out.z = fabs(in0.z);
out.w = fabs(in0.w);
#endif #endif
WRITE_IMAGET(output, (int2)(w, hb), out); WRITE_IMAGET(output, (int2)(w, hb), out);
......
...@@ -4,7 +4,12 @@ __kernel void depth_to_space( ...@@ -4,7 +4,12 @@ __kernel void depth_to_space(
UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3 UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
__read_only image2d_t input, __read_only image2d_t input,
__private const int block_size, __private const int block_size,
__private const int output_depth, __private const int input_height,
__private const int input_width,
__private const int input_depth_blocks,
__private const int output_height,
__private const int output_width,
__private const int output_depth_blocks,
__write_only image2d_t output) { __write_only image2d_t output) {
const int out_d = get_global_id(0); const int out_d = get_global_id(0);
const int out_w = get_global_id(1); const int out_w = get_global_id(1);
...@@ -20,16 +25,21 @@ __kernel void depth_to_space( ...@@ -20,16 +25,21 @@ __kernel void depth_to_space(
const int output_width = get_global_size(1); const int output_width = get_global_size(1);
#endif #endif
if (out_d >= output_depth_blocks || out_h >= output_height || out_w >= output_width)
return;
const int out_pos = mad24(out_d, output_width, out_w); const int out_pos = mad24(out_d, output_width, out_w);
const int input_width = output_width / block_size;
const int in_h = out_h / block_size; const int in_h = out_h / block_size;
const int offset_h = out_h % block_size; const int offset_h = out_h % block_size;
const int in_w = out_w / block_size; const int in_w = out_w / block_size;
const int offset_w = out_w % block_size; const int offset_w = out_w % block_size;
const int offset_d = (offset_h * block_size + offset_w) * output_depth; const int offset_d = (offset_h * block_size + offset_w) * output_depth_blocks;
const int in_d = out_d + offset_d; const int in_d = out_d + offset_d;
if (in_h >= input_height || in_w >= input_width || in_d >= input_depth_blocks)
return;
const int in_pos = mad24(in_d, input_width, in_w); const int in_pos = mad24(in_d, input_width, in_w);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h)); DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h));
WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data);
...@@ -39,7 +49,12 @@ __kernel void space_to_depth( ...@@ -39,7 +49,12 @@ __kernel void space_to_depth(
UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3 UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
__read_only image2d_t input, __read_only image2d_t input,
__private const int block_size, __private const int block_size,
__private const int input_depth, __private const int input_height,
__private const int input_width,
__private const int input_depth_blocks,
__private const int output_height,
__private const int output_width,
__private const int output_depth_blocks,
__write_only image2d_t output) { __write_only image2d_t output) {
const int d = get_global_id(0); const int d = get_global_id(0);
...@@ -57,14 +72,17 @@ __kernel void space_to_depth( ...@@ -57,14 +72,17 @@ __kernel void space_to_depth(
#endif #endif
const int in_pos = mad24(d, input_width, w); const int in_pos = mad24(d, input_width, w);
const int output_width = input_width / block_size;
const int out_h = h / block_size; const int out_h = h / block_size;
const int offset_h = h % block_size; const int offset_h = h % block_size;
const int out_w = w / block_size; const int out_w = w / block_size;
const int offset_w = w % block_size; const int offset_w = w % block_size;
const int offset_d = (offset_h * block_size + offset_w) * input_depth; const int offset_d = (offset_h * block_size + offset_w) * input_depth_blocks;
const int out_d = d + offset_d; const int out_d = d + offset_d;
if (out_d >= output_depth_blocks || out_h >= output_height || out_w >= output_width)
return;
const int out_pos = mad24(out_d, output_width, out_w); const int out_pos = mad24(out_d, output_width, out_w);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h)); DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h));
......
#include <common.h>
// Supported data types: half/float
__kernel void neg(__read_only image2d_t input,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
const int width = get_global_size(1);
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = 0 - in;
WRITE_IMAGET(output, (int2)(pos, hb), out);
}
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/kernels/scalar_math.h" #include "mace/kernels/cwise.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h" #include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h" #include "mace/utils/tuner.h"
...@@ -11,7 +11,7 @@ namespace mace { ...@@ -11,7 +11,7 @@ namespace mace {
namespace kernels { namespace kernels {
template <typename T> template <typename T>
void ScalarMathFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, void CWiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
Tensor *output, Tensor *output,
StatsFuture *future) { StatsFuture *future) {
const index_t batch = input->dim(0); const index_t batch = input->dim(0);
...@@ -27,12 +27,12 @@ void ScalarMathFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -27,12 +27,12 @@ void ScalarMathFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
auto runtime = OpenCLRuntime::Global(); auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options; std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value; auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("scalar_math"); std::string kernel_name = MACE_OBFUSCATE_SYMBOL("cwise");
built_options.emplace("-Dscalar_math=" + kernel_name); built_options.emplace("-Dcwise=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(MakeString("-DSCALAR_MATH_TYPE=", type_)); built_options.emplace(MakeString("-DCWISE_TYPE=", type_));
kernel_ = runtime->BuildKernel("scalar_math", kernel_name, built_options); kernel_ = runtime->BuildKernel("cwise", kernel_name, built_options);
} }
if (!IsVecEqual(input_shape_, input->shape())) { if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0; uint32_t idx = 0;
...@@ -46,12 +46,12 @@ void ScalarMathFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -46,12 +46,12 @@ void ScalarMathFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
static_cast<uint32_t>(batch_height_pixels)}; static_cast<uint32_t>(batch_height_pixels)};
const std::vector<uint32_t> lws = {64, 16, 1}; const std::vector<uint32_t> lws = {64, 16, 1};
std::stringstream ss; std::stringstream ss;
ss << "eltwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) ss << "cwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3); << "_" << output->dim(2) << "_" << output->dim(3);
TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future); TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
} }
template struct ScalarMathFunctor<DeviceType::OPENCL, float>; template struct CWiseFunctor<DeviceType::OPENCL, float>;
template struct ScalarMathFunctor<DeviceType::OPENCL, half>; template struct CWiseFunctor<DeviceType::OPENCL, half>;
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -20,26 +20,22 @@ void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -20,26 +20,22 @@ void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
const index_t input_width = input->dim(2); const index_t input_width = input->dim(2);
const index_t input_depth = input->dim(3); const index_t input_depth = input->dim(3);
int depth_blocks = 1;
const char *kernel_name = nullptr; const char *kernel_name = nullptr;
index_t kernel_width = input_width;
index_t output_height, output_width, output_depth; index_t output_height, output_width, output_depth;
if (d2s_) { if (d2s_) {
output_height = input_height * block_size_; output_height = input_height * block_size_;
output_width = input_width * block_size_; output_width = input_width * block_size_;
output_depth = input_depth / (block_size_ * block_size_); output_depth = input_depth / (block_size_ * block_size_);
depth_blocks = RoundUpDiv4(output_depth);
kernel_name = "depth_to_space"; kernel_name = "depth_to_space";
kernel_width = output_width;
} else { } else {
output_height = input_height / block_size_; output_height = input_height / block_size_;
output_width = input_width / block_size_; output_width = input_width / block_size_;
output_depth = input_depth * block_size_ * block_size_; output_depth = input_depth * block_size_ * block_size_;
depth_blocks = RoundUpDiv4(input_depth);
kernel_name = "space_to_depth"; kernel_name = "space_to_depth";
kernel_width = input_width;
} }
const index_t input_depth_blocks = RoundUpDiv4(input_depth);
const index_t output_depth_blocks = RoundUpDiv4(output_depth);
std::vector<index_t> output_shape = {batch, output_height, output_width, std::vector<index_t> output_shape = {batch, output_height, output_width,
output_depth}; output_depth};
...@@ -94,7 +90,12 @@ void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -94,7 +90,12 @@ void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
} }
kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, static_cast<int32_t>(block_size_)); kernel_.setArg(idx++, static_cast<int32_t>(block_size_));
kernel_.setArg(idx++, static_cast<int32_t>(depth_blocks)); kernel_.setArg(idx++, static_cast<int32_t>(input_height));
kernel_.setArg(idx++, static_cast<int32_t>(input_width));
kernel_.setArg(idx++, static_cast<int32_t>(input_depth_blocks));
kernel_.setArg(idx++, static_cast<int32_t>(output_height));
kernel_.setArg(idx++, static_cast<int32_t>(output_width));
kernel_.setArg(idx++, static_cast<int32_t>(output_depth_blocks));
kernel_.setArg(idx++, *(output->opencl_image())); kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape(); input_shape_ = input->shape();
......
...@@ -22,7 +22,7 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0, ...@@ -22,7 +22,7 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
const index_t channel_blocks = RoundUpDiv4(channels); const index_t channel_blocks = RoundUpDiv4(channels);
const index_t width_pixels = channel_blocks * width; const index_t width_pixels = channel_blocks * width;
const index_t batch_height_pixels = batch * height; const index_t batch_height_pixels = batch * height;
const uint32_t gws[2] = {static_cast<uint32_t>(width_pixels), const uint32_t gws[2] = {static_cast<uint32_t>(width_pixels),
static_cast<uint32_t>(batch_height_pixels)}; static_cast<uint32_t>(batch_height_pixels)};
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/negative.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
template <typename T>
void NegFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
auto runtime = OpenCLRuntime::Global();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("neg");
built_options.emplace("-Dneg=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
kernel_ = runtime->BuildKernel("neg", kernel_name, built_options);
}
if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
}
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8};
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
}
template struct NegFunctor<DeviceType::OPENCL, float>;
template struct NegFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
...@@ -31,7 +31,6 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -31,7 +31,6 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
const Tensor *filter = this->Input(FILTER); const Tensor *filter = this->Input(FILTER);
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr; const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
functor_(input, filter, bias, output, future); functor_(input, filter, bias, output, future);
return true; return true;
......
...@@ -2,29 +2,29 @@ ...@@ -2,29 +2,29 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/ops/neg.h" #include "mace/ops/cwise.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
void Register_Neg(OperatorRegistry *op_registry) { void Register_CWise(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg") REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise")
.Device(DeviceType::CPU) .Device(DeviceType::CPU)
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
NegOp<DeviceType::CPU, float>); CWiseOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg") REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise")
.Device(DeviceType::OPENCL) .Device(DeviceType::OPENCL)
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
NegOp<DeviceType::OPENCL, float>); CWiseOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg") REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise")
.Device(DeviceType::OPENCL) .Device(DeviceType::OPENCL)
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
NegOp<DeviceType::OPENCL, half>); CWiseOp<DeviceType::OPENCL, half>);
} }
} // namespace ops } // namespace ops
......
...@@ -2,27 +2,27 @@ ...@@ -2,27 +2,27 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#ifndef MACE_OPS_SCALAR_MATH_H_ #ifndef MACE_OPS_CWISE_H_
#define MACE_OPS_SCALAR_MATH_H_ #define MACE_OPS_CWISE_H_
#include <string> #include <string>
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/kernels/scalar_math.h" #include "mace/kernels/cwise.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
template <DeviceType D, class T> template <DeviceType D, class T>
class ScalarMathOp : public Operator<D, T> { class CWiseOp : public Operator<D, T> {
public: public:
ScalarMathOp(const OperatorDef &operator_def, Workspace *ws) CWiseOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
x_(OperatorBase::GetSingleArgument<float>("x", 1.0)), x_(OperatorBase::GetSingleArgument<float>("x", 1.0)),
functor_(static_cast<kernels::ScalarMathType>( functor_(static_cast<kernels::CWiseType>(
OperatorBase::GetSingleArgument<int>( OperatorBase::GetSingleArgument<int>(
"type", static_cast<int>( "type", static_cast<int>(
kernels::ScalarMathType::ADD))), kernels::CWiseType::ADD))),
this->x_) {} this->x_) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
...@@ -40,10 +40,10 @@ class ScalarMathOp : public Operator<D, T> { ...@@ -40,10 +40,10 @@ class ScalarMathOp : public Operator<D, T> {
OP_OUTPUT_TAGS(OUTPUT); OP_OUTPUT_TAGS(OUTPUT);
private: private:
kernels::ScalarMathFunctor<D, T> functor_; kernels::CWiseFunctor<D, T> functor_;
}; };
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
#endif // MACE_OPS_SCALAR_MATH_H_ #endif // MACE_OPS_CWISE_H_
...@@ -12,7 +12,7 @@ namespace ops { ...@@ -12,7 +12,7 @@ namespace ops {
namespace test { namespace test {
template <DeviceType D, typename T> template <DeviceType D, typename T>
static void ScalarMath(int iters, int batch, int channels, static void CWise(int iters, int batch, int channels,
int height, int width, float x, int type) { int height, int width, float x, int type) {
mace::testing::StopTiming(); mace::testing::StopTiming();
...@@ -24,14 +24,14 @@ static void ScalarMath(int iters, int batch, int channels, ...@@ -24,14 +24,14 @@ static void ScalarMath(int iters, int batch, int channels,
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("ScalarMath", "ScalarMathBM") OpDefBuilder("CWise", "CWiseBM")
.Input("InputImage") .Input("InputImage")
.Output("Output") .Output("Output")
.AddIntArg("type", type) .AddIntArg("type", type)
.AddFloatArg("x", x) .AddFloatArg("x", x)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else { } else {
OpDefBuilder("ScalarMath", "ScalarMathBM") OpDefBuilder("CWise", "CWiseBM")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.AddIntArg("type", type) .AddIntArg("type", type)
...@@ -52,35 +52,41 @@ static void ScalarMath(int iters, int batch, int channels, ...@@ -52,35 +52,41 @@ static void ScalarMath(int iters, int batch, int channels,
net.Sync(); net.Sync();
} }
#define BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, TYPE, DEVICE) \ #define BM_CWISE_MACRO(N, C, H, W, X, G, TYPE, DEVICE) \
static void \ static void \
BM_SCALAR_MATH_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE( \ BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE( \
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \ mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ScalarMath<DEVICE, TYPE>(iters, N, C, H, W, X, G); \ CWise<DEVICE, TYPE>(iters, N, C, H, W, X, G); \
} \ } \
BENCHMARK( \ BENCHMARK( \
BM_SCALAR_MATH_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE) BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE)
#define BM_SCALAR_MATH(N, C, H, W, X, G) \ #define BM_CWISE(N, C, H, W, X, G) \
BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, float, CPU); \ BM_CWISE_MACRO(N, C, H, W, X, G, float, CPU); \
BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, float, OPENCL); \ BM_CWISE_MACRO(N, C, H, W, X, G, float, OPENCL); \
BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, half, OPENCL); BM_CWISE_MACRO(N, C, H, W, X, G, half, OPENCL);
BM_SCALAR_MATH(1, 1, 512, 512, 2, 0); BM_CWISE(1, 1, 512, 512, 2, 0);
BM_SCALAR_MATH(1, 3, 128, 128, 2, 1); BM_CWISE(1, 3, 128, 128, 2, 1);
BM_SCALAR_MATH(1, 3, 512, 512, 2, 4); BM_CWISE(1, 3, 512, 512, 2, 4);
BM_SCALAR_MATH(1, 32, 112, 112, 2, 5); BM_CWISE(1, 32, 112, 112, 2, 5);
BM_SCALAR_MATH(1, 64, 256, 256, 3, 0); BM_CWISE(1, 32, 112, 112, 2, 6);
BM_SCALAR_MATH(1, 64, 512, 512, 3, 1); BM_CWISE(1, 32, 112, 112, 2, 7);
BM_SCALAR_MATH(1, 128, 56, 56, 3, 4); BM_CWISE(1, 64, 256, 256, 3, 0);
BM_SCALAR_MATH(1, 128, 256, 256, 3, 5); BM_CWISE(1, 64, 512, 512, 3, 1);
BM_SCALAR_MATH(1, 256, 14, 14, 3, 0); BM_CWISE(1, 128, 56, 56, 3, 4);
BM_SCALAR_MATH(1, 512, 14, 14, 3, 1); BM_CWISE(1, 128, 256, 256, 3, 5);
BM_SCALAR_MATH(1, 1024, 7, 7, 3, 4); BM_CWISE(1, 64, 512, 512, 3, 6);
BM_SCALAR_MATH(32, 1, 256, 256, 3, 5); BM_CWISE(1, 64, 512, 512, 3, 7);
BM_CWISE(1, 256, 14, 14, 3, 0);
BM_CWISE(1, 512, 14, 14, 3, 1);
BM_CWISE(1, 1024, 7, 7, 3, 4);
BM_CWISE(32, 1, 256, 256, 3, 5);
BM_CWISE(32, 1, 256, 256, 3, 6);
BM_CWISE(32, 1, 256, 256, 3, 7);
} // namespace test } // namespace test
} // namespace ops } // namespace ops
......
...@@ -4,17 +4,17 @@ ...@@ -4,17 +4,17 @@
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
#include "../kernels/scalar_math.h" #include "../kernels/cwise.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace test { namespace test {
class ScalarMathOpTest : public OpsTestBase {}; class CWiseOpTest : public OpsTestBase {};
template <DeviceType D> template <DeviceType D>
void Simple(const kernels::ScalarMathType type, void Simple(const kernels::CWiseType type,
const std::vector<index_t> &shape, const std::vector<index_t> &shape,
const std::vector<float> &input0, const std::vector<float> &input0,
const float x, const float x,
...@@ -26,7 +26,7 @@ void Simple(const kernels::ScalarMathType type, ...@@ -26,7 +26,7 @@ void Simple(const kernels::ScalarMathType type,
net.AddInputFromArray<D, float>("Input1", shape, input0); net.AddInputFromArray<D, float>("Input1", shape, input0);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
OpDefBuilder("ScalarMath", "ScalarMathTest") OpDefBuilder("CWise", "CWiseTest")
.Input("Input1") .Input("Input1")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x) .AddFloatArg("x", x)
...@@ -38,7 +38,7 @@ void Simple(const kernels::ScalarMathType type, ...@@ -38,7 +38,7 @@ void Simple(const kernels::ScalarMathType type,
} else { } else {
BufferToImage<D, half>(&net, "Input1", "InputImg1", BufferToImage<D, half>(&net, "Input1", "InputImg1",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("ScalarMath", "ScalarMathTest") OpDefBuilder("CWise", "CWiseTest")
.Input("InputImg1") .Input("InputImg1")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x) .AddFloatArg("x", x)
...@@ -57,36 +57,48 @@ void Simple(const kernels::ScalarMathType type, ...@@ -57,36 +57,48 @@ void Simple(const kernels::ScalarMathType type,
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-3); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-3);
} }
TEST_F(ScalarMathOpTest, CPUSimple) { TEST_F(CWiseOpTest, CPUSimple) {
Simple<DeviceType::CPU>(kernels::ScalarMathType::MUL, {1, 1, 2, 3}, Simple<DeviceType::CPU>(kernels::CWiseType::MUL, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6}); {1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6});
Simple<DeviceType::CPU>(kernels::ScalarMathType::ADD, {1, 1, 2, 3}, Simple<DeviceType::CPU>(kernels::CWiseType::ADD, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8}); {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8});
Simple<DeviceType::CPU>(kernels::ScalarMathType::DIV, {1, 1, 2, 3}, Simple<DeviceType::CPU>(kernels::CWiseType::DIV, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60}); {1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60});
Simple<DeviceType::CPU>(kernels::ScalarMathType::SUB, {1, 1, 2, 3}, Simple<DeviceType::CPU>(kernels::CWiseType::SUB, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4}); {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4});
Simple<DeviceType::CPU>(kernels::CWiseType::NEG, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6});
Simple<DeviceType::CPU>(kernels::CWiseType::ABS, {1, 1, 2, 3},
{1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6});
} }
TEST_F(ScalarMathOpTest, GPUSimple) { TEST_F(CWiseOpTest, GPUSimple) {
Simple<DeviceType::OPENCL>(kernels::ScalarMathType::MUL, {1, 1, 2, 3}, Simple<DeviceType::OPENCL>(kernels::CWiseType::MUL, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6}); {1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6});
Simple<DeviceType::OPENCL>(kernels::ScalarMathType::ADD, {1, 1, 2, 3}, Simple<DeviceType::OPENCL>(kernels::CWiseType::ADD, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8}); {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8});
Simple<DeviceType::OPENCL>(kernels::ScalarMathType::DIV, {1, 1, 2, 3}, Simple<DeviceType::OPENCL>(kernels::CWiseType::DIV, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60}); {1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60});
Simple<DeviceType::OPENCL>(kernels::ScalarMathType::SUB, {1, 1, 2, 3}, Simple<DeviceType::OPENCL>(kernels::CWiseType::SUB, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4}); {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4});
Simple<DeviceType::OPENCL>(kernels::CWiseType::NEG, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6});
Simple<DeviceType::OPENCL>(kernels::CWiseType::ABS, {1, 1, 2, 3},
{1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6});
} }
template <DeviceType D, typename T> template <DeviceType D, typename T>
void RandomTest(const kernels::ScalarMathType type, void RandomTest(const kernels::CWiseType type,
const std::vector<index_t> &shape) { const std::vector<index_t> &shape) {
testing::internal::LogToStderr(); testing::internal::LogToStderr();
srand(time(NULL)); srand(time(NULL));
...@@ -97,7 +109,7 @@ void RandomTest(const kernels::ScalarMathType type, ...@@ -97,7 +109,7 @@ void RandomTest(const kernels::ScalarMathType type,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input1", shape); net.AddRandomInput<D, float>("Input1", shape);
OpDefBuilder("ScalarMath", "ScalarMathTest") OpDefBuilder("CWise", "CWiseTest")
.Input("Input1") .Input("Input1")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", 1.2) .AddFloatArg("x", 1.2)
...@@ -110,7 +122,7 @@ void RandomTest(const kernels::ScalarMathType type, ...@@ -110,7 +122,7 @@ void RandomTest(const kernels::ScalarMathType type,
BufferToImage<D, T>(&net, "Input1", "InputImg1", BufferToImage<D, T>(&net, "Input1", "InputImg1",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("ScalarMath", "ScalarMathTest") OpDefBuilder("CWise", "CWiseTest")
.Input("InputImg1") .Input("InputImg1")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", 1.2) .AddFloatArg("x", 1.2)
...@@ -133,25 +145,29 @@ void RandomTest(const kernels::ScalarMathType type, ...@@ -133,25 +145,29 @@ void RandomTest(const kernels::ScalarMathType type,
} }
} }
TEST_F(ScalarMathOpTest, OPENCLRandomFloat) { TEST_F(CWiseOpTest, OPENCLRandomFloat) {
RandomTest<DeviceType::OPENCL, float>(kernels::ScalarMathType::MUL, RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::MUL,
{3, 23, 37, 19}); {3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, float>(kernels::ScalarMathType::ADD, RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::ADD,
{13, 32, 32, 64}); {13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::ScalarMathType::SUB, RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::SUB,
{3, 32, 32, 64}); {3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::ScalarMathType::DIV, RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::DIV,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::NEG,
{13, 32, 32, 64}); {13, 32, 32, 64});
} }
TEST_F(ScalarMathOpTest, OPENCLRandomHalf) { TEST_F(CWiseOpTest, OPENCLRandomHalf) {
RandomTest<DeviceType::OPENCL, half>(kernels::ScalarMathType::MUL, RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::MUL,
{3, 23, 37, 19}); {3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, half>(kernels::ScalarMathType::ADD, RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::ADD,
{13, 32, 32, 64}); {13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::ScalarMathType::SUB, RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::SUB,
{3, 32, 32, 64}); {3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::ScalarMathType::DIV, RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::DIV,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::NEG,
{13, 32, 32, 64}); {13, 32, 32, 64});
} }
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_NEG_H_
#define MACE_OPS_NEG_H_
#include <string>
#include "mace/core/operator.h"
#include "mace/kernels/negative.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class NegOp : public Operator<D, T> {
public:
NegOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_() {}
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(0);
Tensor *output_tensor = this->outputs_[0];
output_tensor->ResizeLike(input_tensor);
functor_(input_tensor, output_tensor, future);
return true;
}
private:
kernels::NegFunctor<D, T> functor_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_NEGATIVE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D, typename T>
static void Neg(int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Neg", "NegBM")
.Input("InputImage")
.Output("Output")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Neg", "NegBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_NEG_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_NEG_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Neg<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_NEG_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_NEG(N, C, H, W) \
BM_NEG_MACRO(N, C, H, W, float, CPU); \
BM_NEG_MACRO(N, C, H, W, float, OPENCL); \
BM_NEG_MACRO(N, C, H, W, half, OPENCL);
BM_NEG(1, 1, 512, 512);
BM_NEG(1, 3, 128, 128);
BM_NEG(1, 3, 512, 512);
BM_NEG(1, 32, 112, 112);
BM_NEG(1, 64, 256, 256);
BM_NEG(1, 64, 512, 512);
BM_NEG(1, 128, 56, 56);
BM_NEG(1, 128, 256, 256);
BM_NEG(1, 256, 14, 14);
BM_NEG(1, 512, 14, 14);
BM_NEG(1, 1024, 7, 7);
BM_NEG(32, 1, 256, 256);
BM_NEG(32, 3, 256, 256);
} // namespace test
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class NegOpTest : public OpsTestBase {};
template <DeviceType D>
void NegSimple() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", {1, 6, 2, 1},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Neg", "NegTest")
.Input("InputImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("Neg", "NegTest")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
// Check
auto expected = CreateTensor<float>(
{1, 6, 2, 1},
{-5, -5, -7, -7, -9, -9, -11, -11, -13, -13, -15, -15});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-8);
}
TEST_F(NegOpTest, NegSimpleCPU) { NegSimple<DeviceType::CPU>(); }
TEST_F(NegOpTest, NegSimpleOPENCL) {
NegSimple<DeviceType::OPENCL>();
}
} // namespace test
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/scalar_math.h"
namespace mace {
namespace ops {
void Register_ScalarMath(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ScalarMathOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
ScalarMathOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
ScalarMathOp<DeviceType::OPENCL, half>);
}
} // namespace ops
} // namespace mace
...@@ -19,14 +19,19 @@ pooling_type_mode = { ...@@ -19,14 +19,19 @@ pooling_type_mode = {
'MaxPool': 2 'MaxPool': 2
} }
# the order should be the same as eltwise type's order # the order should be the same as
# eltwise type's in mace/kernels/eltwise.h
# and also cwise type's in mace/kernels/cwise.h
# cuz these math ops should have compatible with "EltWise" and "CWise"
math_type_mode = { math_type_mode = {
'MUL': 0, 'MUL': 0,
'ADD': 1, 'ADD': 1,
'MAX': 2, 'MAX': 2,
'MIN': 3, 'MIN': 3,
'SUB': 4, 'SUB': 4,
'DIV': 5 'DIV': 5,
'NEG': 6,
'ABS': 7
} }
buffer_type_map = { buffer_type_map = {
...@@ -632,18 +637,6 @@ class TFConverter(object): ...@@ -632,18 +637,6 @@ class TFConverter(object):
self.add_output_shape(op.outputs, op_def) self.add_output_shape(op.outputs, op_def)
self.resolved_ops[op.name] = 1 self.resolved_ops[op.name] = 1
self.unused_tensor.add(get_input_tensor(op, 1).name) self.unused_tensor.add(get_input_tensor(op, 1).name)
def convert_neg(self, op):
op_def = self.net_def.op.add()
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self.dt
op_def.name = op.name
op_def.type = "Neg"
op_def.input.extend([input.name for input in op.inputs])
op_def.output.extend([output.name for output in op.outputs])
self.add_output_shape(op.outputs, op_def)
self.resolved_ops[op.name] = 1
def convert_math(self, op, math_type): def convert_math(self, op, math_type):
op_def = self.net_def.op.add() op_def = self.net_def.op.add()
...@@ -651,24 +644,31 @@ class TFConverter(object): ...@@ -651,24 +644,31 @@ class TFConverter(object):
arg.name = 'T' arg.name = 'T'
arg.i = self.dt arg.i = self.dt
op_def.name = op.name op_def.name = op.name
input_tensor0 = get_input_tensor(op, 0)
input_tensor1 = get_input_tensor(op, 1) if len(op.inputs) == 1:
op_def.type = "CWise"
if input_tensor0.shape == input_tensor1.shape:
op_def.type = "Eltwise"
op_def.input.extend([input.name for input in op.inputs]) op_def.input.extend([input.name for input in op.inputs])
else:
op_def.type = "ScalarMath"
x_value = 0
if len(input_tensor1.shape)==4:
op_def.input.extend([op.inputs[1].name])
x_value = get_input_tensor(op, 0).eval().astype(np.float32)
else:
op_def.input.extend([op.inputs[0].name])
x_value = get_input_tensor(op, 1).eval().astype(np.float32)
x_arg = op_def.arg.add() x_arg = op_def.arg.add()
x_arg.name = 'x' x_arg.name = 'x'
x_arg.f = x_value x_arg.f = 0
elif len(op.inputs) >= 2:
input_tensor0 = get_input_tensor(op, 0)
input_tensor1 = get_input_tensor(op, 1)
if input_tensor0.shape == input_tensor1.shape:
op_def.type = "Eltwise"
op_def.input.extend([input.name for input in op.inputs])
else:
op_def.type = "CWise"
x_value = 0
if len(input_tensor1.shape)==4:
op_def.input.extend([op.inputs[1].name])
x_value = get_input_tensor(op, 0).eval().astype(np.float32)
else:
op_def.input.extend([op.inputs[0].name])
x_value = get_input_tensor(op, 1).eval().astype(np.float32)
x_arg = op_def.arg.add()
x_arg.name = 'x'
x_arg.f = x_value
type_arg = op_def.arg.add() type_arg = op_def.arg.add()
type_arg.name = 'type' type_arg.name = 'type'
type_arg.i = math_type_mode[math_type] type_arg.i = math_type_mode[math_type]
...@@ -919,15 +919,15 @@ class TFConverter(object): ...@@ -919,15 +919,15 @@ class TFConverter(object):
elif op.type == 'BatchToSpaceND': elif op.type == 'BatchToSpaceND':
self.convert_space_to_batch(op, True) self.convert_space_to_batch(op, True)
elif op.type == 'DepthToSpace': elif op.type == 'DepthToSpace':
self.convert_depth_to_space(op, True) self.convert_depth_to_space(op, True)
elif op.type == 'SpaceToDepth': elif op.type == 'SpaceToDepth':
self.convert_depth_to_space(op, False) self.convert_depth_to_space(op, False)
elif op.type == 'Neg': elif op.type in ['Neg', 'neg', 'Negative', 'negative']:
self.convert_neg(op) self.convert_math(op, 'NEG')
elif op.type == 'Mul': elif op.type == 'Mul':
self.convert_math(op, 'MUL') self.convert_math(op, 'MUL')
elif op.type == 'Sub': elif op.type == 'Sub':
self.convert_math(op, 'SUB') self.convert_math(op, 'SUB')
elif self.is_softmax(op): elif self.is_softmax(op):
self.convert_softmax(op) self.convert_softmax(op)
elif op.type in ['Relu', 'Sigmoid', 'Tanh']: elif op.type in ['Relu', 'Sigmoid', 'Tanh']:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册