提交 478b21f3 编写于 作者: L liuqi

Add pad op.

上级 36f8b360
......@@ -83,6 +83,7 @@ extern void Register_FusedConv2D(OperatorRegistry *op_registry);
extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry);
extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Pad(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_PSROIAlign(OperatorRegistry *op_registry);
......@@ -119,6 +120,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_GlobalAvgPooling(this);
ops::Register_ImageToBuffer(this);
ops::Register_MatMul(this);
ops::Register_Pad(this);
ops::Register_Pooling(this);
ops::Register_Proposal(this);
ops::Register_PSROIAlign(this);
......
#include <common.h>
// Constant-value padding along the height and width of an image tensor.
//
// One work-item produces one output pixel (a DATA_TYPE4 of four channels):
//   global id 0 -> channel block index
//   global id 1 -> output width index
//   global id 2 -> fused (batch * output_height + height) index
// Pixels that map back inside the input are copied from `input`; every other
// pixel is filled with `constant_value`.
//
// height_padding / width_padding are the paddings applied before the first
// row / column of the input.
__kernel void pad(KERNEL_ERROR_PARAMS
                  GLOBAL_WORK_GROUP_SIZE_DIM3
                  __read_only image2d_t input,
                  __write_only image2d_t output,
                  __private const float constant_value,
                  __private const int input_height,
                  __private const int input_width,
                  __private const int output_height,
                  __private const int height_padding,
                  __private const int width_padding) {
  const int c_blk = get_global_id(0);
  const int out_w = get_global_id(1);
  const int out_hb = get_global_id(2);

#ifndef NON_UNIFORM_WORK_GROUP
  // The launch grid may be rounded up; drop the work-items past the real
  // extent.
  if (c_blk >= global_size_dim0 || out_w >= global_size_dim1
      || out_hb >= global_size_dim2) {
    return;
  }
  const int out_width = global_size_dim1;
#else
  const int out_width = get_global_size(1);
#endif

  // Split the fused index, then express the position relative to the input.
  const int n = out_hb / output_height;
  const int in_h = (out_hb % output_height) - height_padding;
  const int in_w = out_w - width_padding;

  const bool row_inside = (0 <= in_h) && (in_h < input_height);
  const bool col_inside = (0 <= in_w) && (in_w < input_width);

  DATA_TYPE4 value = constant_value;
  if (row_inside && col_inside) {
    const int in_x = mad24(c_blk, input_width, in_w);
    const int in_y = mad24(n, input_height, in_h);
    value = READ_IMAGET(input, SAMPLER, (int2)(in_x, in_y));
  }

  const int out_x = mad24(c_blk, out_width, out_w);
  WRITE_IMAGET(output, (int2)(out_x, out_hb), value);
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/pad.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
// Pads an OpenCL image tensor with a constant value along height and width.
//
// input:  source tensor; its shape is read as {batch, height, width,
//         channels}.
// output: resized to the padded shape and backed by an IN_OUT_CHANNEL image.
// future: forwarded to TuningOrRun3DKernel.
//
// paddings_ must hold a {before, after} pair for each of the four dimensions;
// the batch and channel pairs must be zero.
template<typename T>
void PadFunctor<DeviceType::OPENCL, T>::operator()(
    const Tensor *input,
    Tensor *output,
    StatsFuture *future) {
  // paddings_[0..7] is indexed below, which is only valid for a 4-D tensor.
  // Check the rank first so a malformed op cannot read out of bounds.
  MACE_CHECK(input->dim_size() == 4) << "Pad only supports 4-D input";
  MACE_CHECK(this->paddings_.size() == (input->dim_size() * 2));
  MACE_CHECK((this->paddings_[0] == 0) && (this->paddings_[1] == 0)
                 && (this->paddings_[6] == 0) && (this->paddings_[7] == 0))
    << "Mace only support height/width dimension now";
  auto input_shape = input->shape();
  std::vector<index_t>
      output_shape = {input_shape[0] + this->paddings_[0] + this->paddings_[1],
                      input_shape[1] + this->paddings_[2] + this->paddings_[3],
                      input_shape[2] + this->paddings_[4] + this->paddings_[5],
                      input_shape[3] + this->paddings_[6] + this->paddings_[7]};

  std::vector<size_t> image_shape;
  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
  output->ResizeImage(output_shape, image_shape);

  const index_t batch = output->dim(0);
  const index_t height = output->dim(1);
  const index_t width = output->dim(2);
  const index_t channels = output->dim(3);

  const index_t channel_blocks = RoundUpDiv4(channels);

  auto runtime = OpenCLRuntime::Global();

  // Build the kernel once, on the first invocation.
  if (kernel_.get() == nullptr) {
    std::set<std::string> built_options;
    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("pad");
    built_options.emplace("-Dpad=" + kernel_name);
    auto dt = DataTypeToEnum<T>::value;
    built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
    built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
    if (runtime->IsOutOfRangeCheckEnabled()) {
      built_options.emplace("-DOUT_OF_RANGE_CHECK");
      // One-byte buffer, cleared here and inspected after each run.
      kernel_error_.reset(
          new Buffer(GetDeviceAllocator(DeviceType::OPENCL), 1));
      kernel_error_->Map(nullptr);
      *(kernel_error_->mutable_data<char>()) = 0;
      kernel_error_->UnMap();
    }
    if (runtime->IsNonUniformWorkgroupsSupported()) {
      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
    }
    kernel_ = runtime->BuildKernel("pad", kernel_name, built_options);

    kwg_size_ =
        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
  }

  // One work-item per (channel block, output column, batch * output row).
  const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
                           static_cast<uint32_t>(width),
                           static_cast<uint32_t>(height * batch)};

  // Kernel arguments are refreshed only when the input shape changes.
  if (!IsVecEqual(input_shape_, input->shape())) {
    int idx = 0;
    if (runtime->IsOutOfRangeCheckEnabled()) {
      kernel_.setArg(idx++,
                     *(static_cast<cl::Buffer *>(kernel_error_->buffer())));
    }
    if (!runtime->IsNonUniformWorkgroupsSupported()) {
      kernel_.setArg(idx++, gws[0]);
      kernel_.setArg(idx++, gws[1]);
      kernel_.setArg(idx++, gws[2]);
    }
    kernel_.setArg(idx++, *(input->opencl_image()));
    kernel_.setArg(idx++, *(output->opencl_image()));
    kernel_.setArg(idx++, this->constant_value_);
    kernel_.setArg(idx++, static_cast<int32_t>(input_shape[1]));
    kernel_.setArg(idx++, static_cast<int32_t>(input_shape[2]));
    kernel_.setArg(idx++, static_cast<int32_t>(output_shape[1]));
    kernel_.setArg(idx++, this->paddings_[2]);
    kernel_.setArg(idx++, this->paddings_[4]);

    input_shape_ = input->shape();
  }

  const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
  std::string tuning_key =
      Concat("pad", output->dim(0), output->dim(1), output->dim(2),
             output->dim(3));
  TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future);

  if (runtime->IsOutOfRangeCheckEnabled()) {
    kernel_error_->Map(nullptr);
    char *kerror_code = kernel_error_->mutable_data<char>();
    // Stream the code as an integer; a raw char would be printed as a
    // (usually unprintable) character.
    MACE_CHECK(*kerror_code == 0) << "Kernel error code: "
                                  << static_cast<int>(*kerror_code);
    kernel_error_->UnMap();
  }
}
template
struct PadFunctor<DeviceType::OPENCL, float>;
template
struct PadFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_PAD_H_
#define MACE_KERNELS_PAD_H_
#include <algorithm>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
// State shared by every PadFunctor specialization: the padding amounts and
// the value used to fill the padded area.
struct PadFunctorBase {
  PadFunctorBase(const std::vector<int> &paddings,
                 const float constant_value)
      : paddings_(paddings),
        constant_value_(constant_value) {}

  // Padding amounts, stored flat as two consecutive entries per tensor
  // dimension (the functors read them as paddings_[2 * dim] and
  // paddings_[2 * dim + 1]).
  std::vector<int> paddings_;

  // Value written to every padded element.
  float constant_value_;
};
// Generic (buffer-based) constant-value padding for 4-D tensors.
//
// The input is read as {batch, height, width, channel}. paddings_ must hold
// eight non-negative entries: a {before, after} pair for each dimension.
template <DeviceType D, typename T>
struct PadFunctor : public PadFunctorBase {
  PadFunctor(const std::vector<int> &paddings,
             const float constant_value)
      : PadFunctorBase(paddings, constant_value) {}

  // Resizes `output` to the padded shape, fills it with constant_value_ and
  // copies `input` into the interior. `future` is not used by this
  // implementation.
  void operator()(const Tensor *input,
                  Tensor *output,
                  StatsFuture *future) {
    // paddings_[0..7] is indexed below, which is only valid for a 4-D tensor.
    MACE_CHECK(input->dim_size() == 4) << "Pad only supports 4-D input";
    MACE_CHECK(this->paddings_.size() == (input->dim_size() * 2));
    // A negative padding would make the copy below write outside `output`.
    MACE_CHECK(std::all_of(this->paddings_.begin(), this->paddings_.end(),
                           [](const int p) { return p >= 0; }))
      << "Pad does not support negative paddings";

    auto input_shape = input->shape();
    output->Resize({input_shape[0] + this->paddings_[0] + this->paddings_[1],
                    input_shape[1] + this->paddings_[2] + this->paddings_[3],
                    input_shape[2] + this->paddings_[4] + this->paddings_[5],
                    input_shape[3] + this->paddings_[6] + this->paddings_[7]});

    Tensor::MappingGuard input_guard(input);
    Tensor::MappingGuard output_guard(output);
    auto input_ptr = input->data<T>();
    T *output_ptr = output->mutable_data<T>();

    // Fill everything with the constant first; the input then overwrites the
    // interior region.
    std::fill(output_ptr, output_ptr + output->size(), this->constant_value_);

    const index_t batch = input->dim(0);
    const index_t height = input->dim(1);
    const index_t width = input->dim(2);
    const index_t channel = input->dim(3);

    const index_t out_height = output->dim(1);
    const index_t out_width = output->dim(2);
    const index_t out_channel = output->dim(3);

    for (index_t b = 0; b < batch; ++b) {
      for (index_t h = 0; h < height; ++h) {
        for (index_t w = 0; w < width; ++w) {
          const index_t in_offset = (((b * height + h) * width) + w) * channel;
          const index_t out_offset =
              (((b + this->paddings_[0]) * out_height
                + (h + this->paddings_[2])) * out_width
               + (w + this->paddings_[4])) * out_channel
              + this->paddings_[6];
          // Copy one pixel's worth of channels. std::copy_n comes from
          // <algorithm>, which this header already includes.
          std::copy_n(input_ptr + in_offset, channel, output_ptr + out_offset);
        }
      }
    }
  }
};
// OpenCL specialization of PadFunctor. Only the declaration lives here; the
// call operator is defined out of line.
template <typename T>
struct PadFunctor<DeviceType::OPENCL, T> : public PadFunctorBase {
  PadFunctor(const std::vector<int> &paddings,
             const float constant_value)
      : PadFunctorBase(paddings, constant_value) {}

  // Pads `input` into `output`; `future` is handed to the kernel launch.
  void operator()(const Tensor *input,
                  Tensor *output,
                  StatsFuture *future);

  // OpenCL kernel object used by the call operator.
  cl::Kernel kernel_;
  // Work-group size limit queried for kernel_.
  uint32_t kwg_size_;
  // Buffer holding the kernel error code (used with out-of-range checking).
  std::unique_ptr<BufferBase> kernel_error_;
  // Input shape seen when the kernel arguments were last set.
  std::vector<index_t> input_shape_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_PAD_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/pad.h"
namespace mace {
namespace ops {
// Registers the "Pad" operator with `op_registry` for every supported
// device / data-type combination: float on CPU, float and half on OpenCL.
void Register_Pad(OperatorRegistry *op_registry) {
// CPU, float.
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pad")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
PadOp<DeviceType::CPU, float>);
// OpenCL, float.
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pad")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
PadOp<DeviceType::OPENCL, float>);
// OpenCL, half.
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pad")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
PadOp<DeviceType::OPENCL, half>);
}
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_PAD_H_
#define MACE_OPS_PAD_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/pad.h"
namespace mace {
namespace ops {
// Operator wrapper around kernels::PadFunctor.
//
// Arguments read from the operator definition:
//   "paddings"       - repeated int, passed to the functor unchanged.
//   "constant_value" - float fill value, defaults to 0.0.
template <DeviceType D, class T>
class PadOp : public Operator<D, T> {
 public:
  PadOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(OperatorBase::GetRepeatedArgument<int>("paddings"),
                 OperatorBase::GetSingleArgument<float>("constant_value", 0.0))
  {}

  // Pads input 0 into output 0. Always reports success.
  bool Run(StatsFuture *future) override {
    const Tensor *src = this->Input(0);
    Tensor *dst = this->Output(0);

    functor_(src, dst, future);
    return true;
  }

 private:
  kernels::PadFunctor<D, T> functor_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_PAD_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Test fixture for the Pad operator; adds nothing on top of OpsTestBase.
class PadTest : public OpsTestBase {};
// Pads a {1, 2, 3, 1} tensor filled with 2 by (1, 2) rows and (1, 2) columns
// using constant 1.0, and compares the result with a hand-written
// {1, 5, 6, 1} tensor. On OpenCL the tensor is converted to an image before
// the op and back to a buffer afterwards.
template <DeviceType D>
void Simple() {
  OpsTestNet net;

  // Input data.
  net.AddRepeatedInput<D, float>("Input", {1, 2, 3, 1}, 2);

  const std::vector<int> pads = {0, 0, 1, 2, 1, 2, 0, 0};

  if (D != DeviceType::OPENCL) {
    OpDefBuilder("Pad", "PadTest")
        .Input("Input")
        .Output("Output")
        .AddIntsArg("paddings", pads)
        .AddFloatArg("constant_value", 1.0)
        .Finalize(net.NewOperatorDef());

    net.RunOp();
  } else {
    BufferToImage<D, float>(&net, "Input", "InputImage",
                            kernels::BufferType::IN_OUT_CHANNEL);

    OpDefBuilder("Pad", "PadTest")
        .Input("InputImage")
        .Output("OutputImage")
        .AddIntsArg("paddings", pads)
        .AddFloatArg("constant_value", 1.0)
        .Finalize(net.NewOperatorDef());

    net.RunOp(D);

    ImageToBuffer<D, float>(&net, "OutputImage", "Output",
                            kernels::BufferType::IN_OUT_CHANNEL);
  }

  auto actual = net.GetTensor("Output");

  // Rows 1-2 / columns 1-3 carry the input; everything else is padding.
  auto reference = CreateTensor<float>({1, 5, 6, 1},
                                       {
                                         1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                                         1.0, 2, 2, 2, 1.0, 1.0,
                                         1.0, 2, 2, 2, 1.0, 1.0,
                                         1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                                         1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                                       });
  ExpectTensorNear<float>(*reference, *actual, 1e-5);
}
// Runs the hand-checked padding case on the CPU implementation.
TEST_F(PadTest, SimpleCPU) {
Simple<DeviceType::CPU>();
}
// Runs the hand-checked padding case on the OpenCL implementation.
TEST_F(PadTest, SimpleGPU) {
Simple<DeviceType::OPENCL>();
}
// CPU-only case that also pads the channel dimension: a {1, 1, 1, 2} tensor
// filled with 2 is padded by one element on each side of height, width and
// channel, giving a {1, 3, 3, 4} tensor whose only non-constant values are
// the two centre channels of the centre pixel.
TEST_F(PadTest, ComplexCPU) {
  OpsTestNet net;

  // Input data.
  net.AddRepeatedInput<DeviceType::CPU, float>("Input", {1, 1, 1, 2}, 2);

  const std::vector<int> pads = {0, 0, 1, 1, 1, 1, 1, 1};
  OpDefBuilder("Pad", "PadTest")
      .Input("Input")
      .Output("Output")
      .AddIntsArg("paddings", pads)
      .AddFloatArg("constant_value", 1.0)
      .Finalize(net.NewOperatorDef());

  net.RunOp();

  auto actual = net.GetTensor("Output");

  // Each line below is one output row (3 pixels x 4 channels).
  auto reference = CreateTensor<float>(
      {1, 3, 3, 4},
      {
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      });
  ExpectTensorNear<float>(*reference, *actual, 1e-5);
}
template <typename T>
void Complex(const std::vector<index_t> &input_shape,
const std::vector<int> &paddings) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>("Input", input_shape);
OpDefBuilder("Pad", "PadTest")
.Input("Input")
.Output("Output")
.AddIntsArg("paddings", paddings)
.AddFloatArg("constant_value", 1.0)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
BufferToImage<DeviceType::OPENCL, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Pad", "PadTest")
.Input("InputImage")
.Output("OutputImage")
.AddIntsArg("paddings", paddings)
.AddFloatArg("constant_value", 1.0)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::OPENCL);
ImageToBuffer<DeviceType::OPENCL, float>(&net, "OutputImage", "OpenCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
auto output = net.GetTensor("OpenCLOutput");
if (DataTypeToEnum<T>::value == DT_HALF) {
ExpectTensorNear<float>(expected, *output, 1e-1);
} else {
ExpectTensorNear<float>(expected, *output, 1e-5);
}
}
// Float comparison of the two Pad implementations over three cases:
// padding on both sides, padding only before, and padding only after.
TEST_F(PadTest, ComplexFloat) {
Complex<float>({1, 32, 32, 4}, {0, 0, 2, 2, 1, 1, 0, 0});
Complex<float>({1, 31, 37, 16}, {0, 0, 2, 0, 1, 0, 0, 0});
Complex<float>({1, 128, 128, 32}, {0, 0, 0, 1, 0, 2, 0, 0});
}
// Same three cases as ComplexFloat, with the input image created as half.
TEST_F(PadTest, ComplexHalf) {
Complex<half>({1, 32, 32, 4}, {0, 0, 2, 2, 1, 1, 0, 0});
Complex<half>({1, 31, 37, 16}, {0, 0, 2, 0, 1, 0, 0, 0});
Complex<half>({1, 128, 128, 32}, {0, 0, 0, 1, 0, 2, 0, 0});
}
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册