提交 6e08809b 编写于 作者: L Liangliang He

Merge branch 'bias_add_op' into 'master'

add bias_add op

See merge request !144
......@@ -141,6 +141,8 @@ const std::map<std::string, std::string>
OpenCLRuntime::program_map_ = {
{"addn", "addn.cl"},
{"batch_norm", "batch_norm.cl"},
{"bias_add", "bias_add.cl"},
{"buffer_to_image", "buffer_to_image.cl"},
{"conv_2d", "conv_2d.cl"},
{"conv_2d_1x1", "conv_2d_1x1.cl"},
{"conv_2d_3x3", "conv_2d_3x3.cl"},
......@@ -150,7 +152,6 @@ const std::map<std::string, std::string>
{"concat", "concat.cl"},
{"resize_bilinear", "resize_bilinear.cl"},
{"space_to_batch", "space_to_batch.cl"},
{"buffer_to_image", "buffer_to_image.cl"},
};
void OpenCLRuntime::BuildProgram(const std::string &program_file_name,
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_BIAS_ADD_H_
#define MACE_KERNELS_BIAS_ADD_H_
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct BiasAddFunctor {
void operator()(const Tensor *input,
const Tensor *bias,
Tensor *output) {
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard bias_mapper(bias);
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
const T *bias_ptr = bias->data<T>();
T *output_ptr = output->mutable_data<T>();
index_t pos = 0;
#pragma omp parallel for
for (index_t n = 0; n < batch; ++n) {
for (index_t h = 0; h < height; ++h) {
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < channels; ++c) {
output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
++pos;
}
}
}
}
}
};
/*
template <>
void BiasAddFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *bias,
Tensor *output);
*/
template <typename T>
struct BiasAddFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *input,
const Tensor *bias,
Tensor *output);
};
} // namepsace kernels
} // namespace mace
#endif // MACE_KERNELS_BIAS_ADD_H_
......@@ -44,11 +44,11 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(scale->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(offset->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(mean->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(var->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(epsilon->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(scale->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(offset->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(mean->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(var->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(epsilon->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/bias_add.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
template <typename T>
void BiasAddFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input,
const Tensor *bias,
Tensor *output) {
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = OpenCLRuntime::Get();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
auto bias_kernel = runtime->BuildKernel("bias_add", "bias_add", built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bias_kernel);
const std::vector<uint32_t> lws = {1, kwg_size, 1};
uint32_t idx = 0;
bias_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
bias_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
bias_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bias_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]),
NULL, OpenCLRuntime::Get()->GetDefaultEvent());
MACE_CHECK(error == CL_SUCCESS);
}
template
struct BiasAddFunctor<DeviceType::OPENCL, float>;
template
struct BiasAddFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
#include <common.h>
// Supported data types: half/float
__kernel void bias_add(__read_only image2d_t input,
__read_only image2d_t bias,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
const int width = get_global_size(1);
const int pos = ch_blk * width + w;
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 bias_value = READ_IMAGET(bias, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 out = in + bias_value;
WRITE_IMAGET(output, (int2)(pos, hb), out);
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/bias_add.h"
namespace mace {
REGISTER_CPU_OPERATOR(OpKeyBuilder("BiasAdd")
.TypeConstraint<float>("T")
.Build(),
BiasAddOp<DeviceType::CPU, float>);
/*
#if __ARM_NEON
REGISTER_NEON_OPERATOR(OpKeyBuilder("BiasAdd")
.TypeConstraint<float>("T")
.Build(),
BiasAddOp<DeviceType::NEON, float>);
#endif // __ARM_NEON
*/
REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BiasAdd")
.TypeConstraint<float>("T")
.Build(),
BiasAddOp<DeviceType::OPENCL, float>);
REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BiasAdd")
.TypeConstraint<half>("T")
.Build(),
BiasAddOp<DeviceType::OPENCL, half>);
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_BIAS_ADD_H_
#define MACE_BIAS_ADD_H_
#include "mace/core/operator.h"
#include "mace/kernels/bias_add.h"
namespace mace {
template <DeviceType D, class T>
class BiasAddOp : public Operator<D, T> {
public:
BiasAddOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), functor_() {}
bool Run() override {
const Tensor *input = this->Input(INPUT);
const Tensor *bias = this->Input(BIAS);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size());
MACE_CHECK(bias->dim_size() == 1, "bias must be 1-dimensional. ",
bias->dim_size());
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
functor_(input, bias, output);
return true;
}
private:
kernels::BiasAddFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, BIAS);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_BIAS_ADD_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
template <DeviceType D, typename T>
static void BiasAdd(
int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
net.AddRandomInput<D, T>("Bias", {channels}, true);
if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("BiasAdd", "BiasAddBM")
.Input("InputImage")
.Input("BiasImage")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
else {
OpDefBuilder("BiasAdd", "BiasAddBM")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_BIAS_ADD_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_BIAS_ADD_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
BiasAdd<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_BIAS_ADD_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_BIAS_ADD(N, C, H, W, TYPE) \
BM_BIAS_ADD_MACRO(N, C, H, W, TYPE, CPU); \
BM_BIAS_ADD_MACRO(N, C, H, W, TYPE, OPENCL);
BM_BIAS_ADD(1, 1, 512, 512, float);
BM_BIAS_ADD(1, 3, 128, 128, float);
BM_BIAS_ADD(1, 3, 512, 512, float);
BM_BIAS_ADD(1, 32, 112, 112, float);
BM_BIAS_ADD(1, 64, 256, 256, float);
BM_BIAS_ADD(1, 64, 512, 512, float);
BM_BIAS_ADD(1, 128, 56, 56, float);
BM_BIAS_ADD(1, 128, 256, 256, float);
BM_BIAS_ADD(1, 256, 14, 14, float);
BM_BIAS_ADD(1, 512, 14, 14, float);
BM_BIAS_ADD(1, 1024, 7, 7, float);
BM_BIAS_ADD(32, 1, 256, 256, float);
BM_BIAS_ADD(32, 3, 256, 256, float);
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
class BiasAddOpTest : public OpsTestBase {};
template <DeviceType D>
void BiasAddSimple() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", {1, 6, 2, 1},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
net.AddInputFromArray<D, float>("Bias", {1}, {0.5f});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D, float>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputImage")
.Input("BiasImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
// Check
auto expected =
CreateTensor<float>({1, 6, 2, 1}, {5.5, 5.5, 7.5, 7.5, 9.5, 9.5, 11.5,
11.5, 13.5, 13.5, 15.5, 15.5});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-2);
}
TEST_F(BiasAddOpTest, BiasAddSimpleCPU) {
BiasAddSimple<DeviceType::CPU>();
}
TEST_F(BiasAddOpTest, BiasAddSimpleOPENCL) {
BiasAddSimple<DeviceType::OPENCL>();
}
TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 64 + rand() % 50;
index_t width = 64 + rand() % 50;
// Construct graph
auto &net = test_net();
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>("Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {channels}, true);
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::OPENCL, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<DeviceType::OPENCL, float>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputImage")
.Input("BiasImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
net.Sync();
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-2);
}
TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 103 + rand() % 100;
index_t width = 113 + rand() % 100;
// Construct graph
auto &net = test_net();
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>("Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {channels}, true);
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::OPENCL, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<DeviceType::OPENCL, float>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputImage")
.Input("BiasImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
net.Sync();
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-2);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册