Commit 07927d5a authored by: liuqi

Add FC op.

Parent 8d8bbfcf
@@ -82,6 +82,7 @@ extern void Register_WinogradTransform(OperatorRegistry *op_registry);
 extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
 extern void Register_Reshape(OperatorRegistry *op_registry);
 extern void Register_Eltwise(OperatorRegistry *op_registry);
+extern void Register_FullyConnected(OperatorRegistry *op_registry);
 OperatorRegistry::OperatorRegistry() {
   Register_Activation(this);
@@ -107,6 +108,7 @@ OperatorRegistry::OperatorRegistry() {
   Register_WinogradInverseTransform(this);
   Register_Reshape(this);
   Register_Eltwise(this);
+  Register_FullyConnected(this);
 }
 } // namespace mace
...
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_FULLY_CONNECTED_H_
#define MACE_KERNELS_FULLY_CONNECTED_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/kernels/activation.h"
namespace mace {
namespace kernels {
struct FullyConnectedBase {
FullyConnectedBase(const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: activation_(activation),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
const ActivationType activation_;
const float relux_max_limit_;
const float prelu_alpha_;
};
template <DeviceType D, typename T>
struct FullyConnectedFunctor : FullyConnectedBase {
  FullyConnectedFunctor(const ActivationType activation,
                        const float relux_max_limit,
                        const float prelu_alpha)
      : FullyConnectedBase(activation, relux_max_limit, prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
output->Resize(output_shape);
const index_t N = output->dim(0);
const index_t input_size = weight->dim(1);
const index_t output_size = weight->dim(0);
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_weight(weight);
Tensor::MappingGuard guard_bias(bias);
Tensor::MappingGuard guard_output(output);
const T *input_ptr = input->data<T>();
const T *weight_ptr = weight->data<T>();
const T *bias_ptr = bias == nullptr ? nullptr : bias->data<T>();
T *output_ptr = output->mutable_data<T>();
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < N; ++i) {
for (index_t out_idx = 0; out_idx < output_size; ++out_idx) {
T sum = 0;
if (bias_ptr != nullptr) sum = bias_ptr[out_idx];
index_t input_offset = i * input_size;
index_t weight_offset = out_idx * input_size;
for (index_t in_idx = 0; in_idx < input_size; ++in_idx) {
sum += input_ptr[input_offset] * weight_ptr[weight_offset];
input_offset++;
weight_offset++;
}
output_ptr[i * output_size + out_idx] = sum;
}
}
DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
relux_max_limit_, prelu_alpha_);
}
};
template <typename T>
struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
  FullyConnectedFunctor(const ActivationType activation,
                        const float relux_max_limit,
                        const float prelu_alpha)
      : FullyConnectedBase(activation, relux_max_limit, prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_FULLY_CONNECTED_H_
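For reference, the CPU functor above is a per-batch GEMV: output[i][o] = bias[o] + dot(row i of the flattened input, row o of the weight). A minimal standalone sketch of the same contract, using a hypothetical free function and no MACE types (illustrative only):

#include <cstddef>
#include <vector>

// Sketch: input flattened to [batch, input_size], weight
// [output_size, input_size] row-major, bias [output_size] (empty = no bias),
// returns output [batch, output_size].
std::vector<float> FullyConnectedRef(const std::vector<float> &input,
                                     const std::vector<float> &weight,
                                     const std::vector<float> &bias,
                                     std::size_t batch,
                                     std::size_t input_size,
                                     std::size_t output_size) {
  std::vector<float> output(batch * output_size);
  for (std::size_t i = 0; i < batch; ++i) {
    for (std::size_t o = 0; o < output_size; ++o) {
      float sum = bias.empty() ? 0.0f : bias[o];
      for (std::size_t k = 0; k < input_size; ++k) {
        sum += input[i * input_size + k] * weight[o * input_size + k];
      }
      output[i * output_size + o] = sum;
    }
  }
  return output;
}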
@@ -54,7 +54,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
       tuning_key_prefix = "sigmoid_opencl_kernel_";
       built_options.emplace("-DUSE_SIGMOID");
       break;
-    defeult:
+    default:
       LOG(FATAL) << "Unknown activation type: " << activation_;
   }
   kernel_ =
...
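A note on this and the identical `defeult:` fixes below: the misspelling compiled cleanly because inside a switch `defeult:` parses as an ordinary (unused) goto label, so an unknown activation type skipped the whole switch instead of reaching LOG(FATAL). A minimal illustration of the pitfall:

#include <iostream>

int main() {
  int activation = 42;  // a value no case handles
  switch (activation) {
    case 0:
      std::cout << "noop\n";
      break;
    defeult:  // typo: parsed as a goto label, NOT as the default case
      std::cout << "never reached for activation == 42\n";
  }
  // Control falls through to here silently; with a real `default:`
  // the diagnostic above would run.
  return 0;
}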
@@ -59,7 +59,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
     case SIGMOID:
       built_options.emplace("-DUSE_SIGMOID");
       break;
-    defeult:
+    default:
       LOG(FATAL) << "Unknown activation type: " << activation_;
   }
...
@@ -48,6 +48,7 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(Tensor *buffer,
       kernel_name = i2b_ ? "arg_image_to_buffer" : "arg_buffer_to_image";
       break;
     case IN_OUT_HEIGHT:
+    case WEIGHT_HEIGHT:
       kernel_name = i2b_ ? "in_out_height_image_to_buffer" : "in_out_height_buffer_to_image";
       break;
     case IN_OUT_WIDTH:
@@ -80,6 +81,10 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(Tensor *buffer,
   b2f_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(buffer->buffer())));
   if (type == ARGUMENT) {
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
+  } else if (type == WEIGHT_HEIGHT) {
+    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
+    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
+    b2f_kernel.setArg(idx++, 1);
   } else {
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
...
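A note on the WEIGHT_HEIGHT branch above: it reuses the in_out_height transform kernels, presenting the 2-D FC weight to them as a height-major buffer. Assuming the three arguments mirror the height/width/channels triple of the IN_OUT_HEIGHT path (an inference from the surrounding branches, not spelled out in this diff):

// Hypothetical argument mapping for an FC weight of shape {out_size, in_size}:
//   buffer->dim(0) = out_size -> "height" (packed by 4 along image rows)
//   buffer->dim(1) = in_size  -> "width"
//   1                         -> "channels" (a 2-D weight has no channel axis)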
@@ -9,8 +9,8 @@ __kernel void batch_norm(__read_only image2d_t input,
                          __private const float epsilon,
 #endif
                          __write_only image2d_t output,
-                         __private const DATA_TYPE relux_max_limit,
-                         __private const DATA_TYPE prelu_alpha) {
+                         __private const float relux_max_limit,
+                         __private const float prelu_alpha) {
   const int ch_blk = get_global_id(0);
   const int w = get_global_id(1);
   const int hb = get_global_id(2);
...
@@ -22,8 +22,8 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP |
 inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
-                                __private const DATA_TYPE relux_max_limit,
-                                __private const DATA_TYPE prelu_alpha) {
+                                __private const float relux_max_limit,
+                                __private const float prelu_alpha) {
   DATA_TYPE4 out;
 #ifdef USE_RELU
   out = fmax(in, 0);
...
@@ -6,8 +6,8 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
                       __read_only image2d_t bias, /* cout%4 * cout/4 */
 #endif
                       __write_only image2d_t output,
-                      __private const DATA_TYPE relux_max_limit,
-                      __private const DATA_TYPE prelu_alpha,
+                      __private const float relux_max_limit,
+                      __private const float prelu_alpha,
                       __private const int in_height,
                       __private const int in_width,
                       __private const int in_ch_blks,
...
@@ -6,8 +6,8 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
                           __read_only image2d_t bias, /* cout%4 * cout/4 */
 #endif
                           __write_only image2d_t output,
-                          __private const DATA_TYPE relux_max_limit,
-                          __private const DATA_TYPE prelu_alpha,
+                          __private const float relux_max_limit,
+                          __private const float prelu_alpha,
                           __private const int in_height,
                           __private const int in_width,
                           __private const int in_ch_blks,
...
@@ -6,8 +6,8 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
                           __read_only image2d_t bias, /* cout%4 * cout/4 */
 #endif
                           __write_only image2d_t output,
-                          __private const DATA_TYPE relux_max_limit,
-                          __private const DATA_TYPE prelu_alpha,
+                          __private const float relux_max_limit,
+                          __private const float prelu_alpha,
                           __private const int in_height,
                           __private const int in_width,
                           __private const int in_ch_blks,
...
@@ -7,8 +7,8 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h
                                __read_only image2d_t bias, /* cout%4 * cout/4 */
 #endif
                                __write_only image2d_t output,
-                               __private const DATA_TYPE relux_max_limit,
-                               __private const DATA_TYPE prelu_alpha,
+                               __private const float relux_max_limit,
+                               __private const float prelu_alpha,
                                __private const short in_height,
                                __private const short in_width,
                                __private const short in_ch_blks,
...
#include <common.h>
// output = weight * input + bias
__kernel void fc(__read_only image2d_t input,
__read_only image2d_t weight,
#ifdef BIAS
__read_only image2d_t bias,
#endif
__write_only image2d_t output,
__private const int input_height,
__private const int input_width,
__private const int input_channel,
__private const float relux_max_limit,
__private const float prelu_alpha) {
const int batch_idx = get_global_id(0);
const int out_blk_idx = get_global_id(1);
const int input_chan_blk = (input_channel + 3) >> 2;
float4 input_value;
float4 w0, w1, w2, w3;
#ifdef BIAS
DATA_TYPE4 result = READ_IMAGET(bias, SAMPLER, (int2)(out_blk_idx, 0));
#else
DATA_TYPE4 result = (DATA_TYPE4)(0, 0, 0, 0);
#endif
int2 input_coord = (int2)(0, mul24(batch_idx, input_height));
int weight_x = 0;
for (short h_idx = 0; h_idx < input_height; ++h_idx) {
for (short w_idx = 0; w_idx < input_width; ++w_idx) {
input_coord.x = w_idx;
weight_x = (h_idx * input_width + w_idx) * input_channel;
#pragma unroll
for (short chan_idx = 0; chan_idx < input_chan_blk; ++chan_idx) {
input_value = READ_IMAGET(input, SAMPLER, input_coord);
w0 = READ_IMAGET(weight, SAMPLER, (int2)(weight_x++, out_blk_idx));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(weight_x++, out_blk_idx));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(weight_x++, out_blk_idx));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(weight_x++, out_blk_idx));
result = mad(input_value.x, w0, result);
result = mad(input_value.y, w1, result);
result = mad(input_value.z, w2, result);
result = mad(input_value.w, w3, result);
input_coord.x += input_width;
}
}
input_coord.y++;
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
result = do_activation(result, relux_max_limit, prelu_alpha);
#endif
WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result);
}
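Each work-item of the kernel above produces one float4, i.e. the four outputs of block out_blk_idx for one batch item, accumulating a texel of four input channels against four weight rows per step via mad(). A scalar sketch of what one work-item computes, ignoring the image2d packing (illustrative only):

// Scalar equivalent of one fc.cl work-item (illustrative; the real kernel
// reads 4-channel texels and fuses the multiply-adds with mad()).
void FcWorkItem(const float *input,   // one batch item, flattened [input_size]
                const float *weight,  // [output_size][input_size], row-major
                const float *bias,    // [output_size], or nullptr when absent
                float *output,        // [output_size]
                int input_size,
                int out_blk_idx) {
  for (int r = 0; r < 4; ++r) {
    const int o = 4 * out_blk_idx + r;
    float acc = (bias != nullptr) ? bias[o] : 0.0f;
    for (int k = 0; k < input_size; ++k) {
      acc += input[k] * weight[o * input_size + k];
    }
    output[o] = acc;
  }
}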
@@ -115,8 +115,8 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
                                              __private const int out_width,
                                              __private const int round_hw,
                                              __private const int round_w,
-                                             __private const DATA_TYPE relux_max_limit,
-                                             __private const DATA_TYPE prelu_alpha) {
+                                             __private const float relux_max_limit,
+                                             __private const float prelu_alpha) {
   const int width_idx = get_global_id(0);
   const int height_idx = get_global_id(1);
   const int out_channel = get_global_size(1);
...
@@ -66,7 +66,7 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
     case SIGMOID:
       built_options.emplace("-DUSE_SIGMOID");
       break;
-    defeult:
+    default:
       LOG(FATAL) << "Unknown activation type: " << activation;
   }
...
@@ -61,7 +61,7 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
     case SIGMOID:
       built_options.emplace("-DUSE_SIGMOID");
       break;
-    defeult:
+    default:
       LOG(FATAL) << "Unknown activation type: " << activation;
   }
...
@@ -61,7 +61,7 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
     case SIGMOID:
       built_options.emplace("-DUSE_SIGMOID");
       break;
-    defeult:
+    default:
       LOG(FATAL) << "Unknown activation type: " << activation;
   }
...
@@ -78,7 +78,7 @@ void DepthwiseConv2d(cl::Kernel *kernel,
     case SIGMOID:
       built_options.emplace("-DUSE_SIGMOID");
       break;
-    defeult:
+    default:
       LOG(FATAL) << "Unknown activation type: " << activation;
   }
...
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/fully_connected.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template<typename T>
void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
const index_t batch = output->dim(0);
const index_t output_size = output->dim(3);
const index_t output_blocks = RoundUpDiv4(output_size);
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("fc");
built_options.emplace("-Dfc=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
switch (activation_) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
kernel_ = runtime->BuildKernel("fc", kernel_name, built_options);
uint32_t idx = 0;
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(input->buffer())));
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(weight->buffer())));
if (bias != nullptr) {
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(bias->buffer())));
}
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
kernel_.setArg(idx++, static_cast<int>(input->dim(1)));
kernel_.setArg(idx++, static_cast<int>(input->dim(2)));
kernel_.setArg(idx++, static_cast<int>(input->dim(3)));
// FIXME handle flexible data type: half not supported
kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, prelu_alpha_);
}
const uint32_t gws[2] = {
static_cast<uint32_t>(batch),
static_cast<uint32_t>(output_blocks),
};
const std::vector<uint32_t> lws = {16, 64, 1};
std::stringstream ss;
ss << "fc_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
}
template
struct FullyConnectedFunctor<DeviceType::OPENCL, float>;
template
struct FullyConnectedFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
@@ -73,6 +73,15 @@ void CalInOutWidthImageShape(const std::vector<index_t> &shape, /* NHWC */
   image_shape[1] = shape[0] * shape[1];
 }
+// [W, (H + 3) / 4]
+void CalWeightHeightImageShape(const std::vector<index_t> &shape, /* HW */
+                               std::vector<size_t> &image_shape) {
+  MACE_CHECK(shape.size() == 2);
+  image_shape.resize(2);
+  image_shape[0] = shape[1];
+  image_shape[1] = RoundUpDiv4(shape[0]);
+}
 void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
                      const BufferType type,
                      std::vector<size_t> &image_shape) {
@@ -98,6 +107,9 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
     case WINOGRAD_FILTER:
       CalWinogradFilterImageShape(shape, image_shape);
       break;
+    case WEIGHT_HEIGHT:
+      CalWeightHeightImageShape(shape, image_shape);
+      break;
     default:
       LOG(FATAL) << "Mace not supported yet.";
   }
...
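As a worked example of the new mapping (hypothetical sizes): an FC weight of shape {1000, 2048}, i.e. out_size x input_size, becomes a 2048 x 250 image, each image row packing four consecutive output channels:

std::vector<size_t> image_shape;
CalWeightHeightImageShape({1000, 2048}, image_shape);
// image_shape[0] == 2048  (width  = input_size)
// image_shape[1] == 250   (height = RoundUpDiv4(1000) = (1000 + 3) / 4)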
@@ -24,6 +24,7 @@ enum BufferType {
   IN_OUT_WIDTH = 4,
   WINOGRAD_FILTER = 5,
   DW_CONV2D_FILTER = 6,
+  WEIGHT_HEIGHT = 7,
 };
 void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
...
@@ -101,7 +101,7 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te
     case SIGMOID:
       built_options.emplace("-DUSE_SIGMOID");
       break;
-    defeult:
+    default:
       LOG(FATAL) << "Unknown activation type: " << activation_;
   }
...
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/fully_connected.h"
namespace mace {
void Register_FullyConnected(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
FullyConnectedOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
FullyConnectedOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
FullyConnectedOp<DeviceType::OPENCL, half>);
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_FULLY_CONNECTED_H_
#define MACE_OPS_FULLY_CONNECTED_H_
#include "mace/core/operator.h"
#include "mace/kernels/fully_connected.h"
namespace mace {
template <DeviceType D, class T>
class FullyConnectedOp : public Operator<D, T> {
public:
FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *weight = this->Input(WEIGHT);
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
Tensor *output = this->Output(OUTPUT);
const index_t input_size = input->dim(1) * input->dim(2) * input->dim(3);
MACE_CHECK(input_size == weight->dim(1) &&
           (bias == nullptr || bias->dim(0) == weight->dim(0)))
    << "The size of Input, Weight and Bias don't match.";
functor_(input, weight, bias, output, future);
return true;
}
private:
kernels::FullyConnectedFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, WEIGHT, BIAS);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_OPS_FULLY_CONNECTED_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <string>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
template <DeviceType D, typename T>
static void FCBenchmark(
int iters, int batch, int height, int width, int channel, int out_channel) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channel});
net.AddRandomInput<D, float>("Weight", {out_channel, height * width * channel});
net.AddRandomInput<D, float>("Bias", {out_channel});
if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_HEIGHT);
BufferToImage<D, T>(net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_FC_MACRO(N, H, W, C, OC, TYPE, DEVICE) \
static void BM_FC_##N##_##H##_##W##_##C##_##OC##_##TYPE##_##DEVICE(int iters) { \
const int64_t macc = static_cast<int64_t>(iters) * N * C * H * W * OC + OC; \
const int64_t tot = static_cast<int64_t>(iters) * (N + OC) * C * H * W + OC; \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
FCBenchmark<DEVICE, TYPE>(iters, N, H, W, C, OC); \
} \
BENCHMARK(BM_FC_##N##_##H##_##W##_##C##_##OC##_##TYPE##_##DEVICE)
#define BM_FC(N, H, W, C, OC) \
BM_FC_MACRO(N, H, W, C, OC, float, CPU); \
BM_FC_MACRO(N, H, W, C, OC, float, OPENCL); \
BM_FC_MACRO(N, H, W, C, OC, half, OPENCL);
BM_FC(1, 16, 16, 32, 32);
BM_FC(1, 8, 8, 32, 1000);
} // namespace mace
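As a worked check of the MaccProcessed accounting above: BM_FC(1, 16, 16, 32, 32) counts macc = iters * (1 * 32 * 16 * 16 * 32) + 32 = iters * 262,144 + 32, i.e. one multiply-accumulate per (input element, output channel) pair per iteration, plus a one-off add per bias element.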
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <cstdlib>
#include <ctime>
#include <fstream>
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
class FullyConnectedOpTest : public OpsTestBase {};
template<DeviceType D>
void Simple(const std::vector<index_t> &input_shape,
const std::vector<float> &input_value,
const std::vector<index_t> &weight_shape,
const std::vector<float> &weight_value,
const std::vector<index_t> &bias_shape,
const std::vector<float> &bias_value,
const std::vector<index_t> &output_shape,
const std::vector<float> &output_value) {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input_value);
net.AddInputFromArray<D, float>("Weight", weight_shape, weight_value);
net.AddInputFromArray<D, float>("Bias", bias_shape, bias_value);
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_HEIGHT);
BufferToImage<D, float>(net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
// Check
auto expected =
CreateTensor<float>(output_shape, output_value);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(FullyConnectedOpTest, SimpleCPU) {
Simple<DeviceType::CPU>({1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8},
{1, 8},
{1, 2, 3, 4, 5, 6, 7, 8},
{1}, {2},
{1, 1, 1, 1}, {206});
Simple<DeviceType::CPU>({1, 1, 2, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{2, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{2}, {2, 3},
{1, 1, 1, 2}, {387, 3853});
Simple<DeviceType::CPU>({1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{5, 6},
{1, 2, 3, 4, 5, 6,
10, 20, 30, 40, 50, 60,
1, 2, 3, 4, 5, 6,
10, 20, 30, 40, 50, 60,
1, 2, 3, 4, 5, 6},
{5}, {1, 2, 3, 4, 5},
{1, 1, 1, 5}, {92, 912, 94, 914, 96});
}
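As a sanity check of the first case above: the flattened input (1..8) dotted with the single weight row (1..8) gives 1 + 4 + 9 + 16 + 25 + 36 + 49 + 64 = 204, and adding the bias 2 yields the expected 206.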
TEST_F(FullyConnectedOpTest, SimpleCPUWithBatch) {
Simple<DeviceType::CPU>({2, 1, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8},
{1, 4},
{1, 2, 3, 4},
{1}, {2},
{2, 1, 1, 1}, {32, 72});
}
TEST_F(FullyConnectedOpTest, SimpleOPENCL) {
Simple<DeviceType::OPENCL>({1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8},
{1, 8},
{1, 2, 3, 4, 5, 6, 7, 8},
{1}, {2},
{1, 1, 1, 1}, {206});
Simple<DeviceType::OPENCL>({1, 1, 2, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{2, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{2}, {2, 3},
{1, 1, 1, 2}, {387, 3853});
Simple<DeviceType::OPENCL>({1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{5, 6},
{1, 2, 3, 4, 5, 6,
10, 20, 30, 40, 50, 60,
1, 2, 3, 4, 5, 6,
10, 20, 30, 40, 50, 60,
1, 2, 3, 4, 5, 6},
{5}, {1, 2, 3, 4, 5},
{1, 1, 1, 5}, {92, 912, 94, 914, 96});
}
TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) {
Simple<DeviceType::OPENCL>({2, 1, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8},
{1, 4},
{1, 2, 3, 4},
{1}, {2},
{2, 1, 1, 1}, {32, 72});
}
template<typename T>
void Complex(const index_t batch,
const index_t height,
const index_t width,
const index_t channels,
const index_t out_channel) {
srand(time(NULL));
// Construct graph
OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>(
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::OPENCL, float>(
"Bias", {out_channel});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::OPENCL, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<DeviceType::OPENCL, T>(net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_HEIGHT);
BufferToImage<DeviceType::OPENCL, float>(net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1);
} else {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-3);
}
}
TEST_F(FullyConnectedOpTest, OPENCLAlignedWithoutBatch) {
Complex<float>(1, 16, 16, 32, 16);
Complex<float>(1, 16, 32, 32, 32);
}
TEST_F(FullyConnectedOpTest, OPENCLUnAlignedWithoutBatch) {
Complex<float>(1, 13, 11, 11, 17);
Complex<float>(1, 23, 29, 23, 113);
}
TEST_F(FullyConnectedOpTest, OPENCLUnAlignedWithBatch) {
Complex<float>(16, 11, 13, 23, 17);
Complex<float>(31, 13, 11, 29, 113);
}
TEST_F(FullyConnectedOpTest, OPENCLHalfAlignedWithoutBatch) {
Complex<half>(1, 16, 16, 32, 16);
Complex<half>(1, 16, 32, 32, 32);
}
TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) {
Complex<half>(2, 11, 13, 61, 17);
Complex<half>(16, 13, 12, 31, 113);
Complex<half>(31, 21, 11, 23, 103);
}
} // namespace mace