Commit c1534c96 authored by 李寅

Merge branch 'slice' into 'master'

Add Slice op, and extend memory reuse to support multiple outputs.

See merge request !280
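For context, the Slice op introduced here splits the channel (last) dimension of an NHWC tensor into equal parts, one per output, and the preallocated-image reuse now carries one mem_id per output. A minimal standalone sketch of the slice semantics, in plain C++ without the MACE Tensor API (all names below are illustrative, not part of the framework):

#include <cstddef>
#include <vector>

// Split the channel (innermost) dimension of an NHWC tensor into
// num_outputs equal slices; input holds N * H * W * channels values.
// Illustrative only; the real op operates on mace::Tensor objects.
std::vector<std::vector<float>> SliceChannels(const std::vector<float> &input,
                                              std::size_t outer_size,  // N*H*W
                                              std::size_t channels,
                                              std::size_t num_outputs) {
  const std::size_t out_channels = channels / num_outputs;
  std::vector<std::vector<float>> outputs(
      num_outputs, std::vector<float>(outer_size * out_channels));
  for (std::size_t o = 0; o < outer_size; ++o) {
    for (std::size_t i = 0; i < num_outputs; ++i) {
      for (std::size_t c = 0; c < out_channels; ++c) {
        outputs[i][o * out_channels + c] =
            input[o * channels + i * out_channels + c];
      }
    }
  }
  return outputs;
}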
@@ -153,7 +153,9 @@ void OperatorDef::CopyFrom(const OperatorDef &from) {
   output_type_.resize(from_data_type.size());
   std::copy(from_data_type.begin(), from_data_type.end(), output_type_.begin());
-  mem_id_ = from.mem_id();
+  auto mem_ids = from.mem_id();
+  mem_id_.resize(mem_ids.size());
+  std::copy(mem_ids.begin(), mem_ids.end(), mem_id_.begin());
   // nnlib
   node_id_ = from.node_id();
@@ -186,13 +188,11 @@ void OperatorDef::set_type(const std::string &type_) {
 }
 bool OperatorDef::has_type() const { return (has_bits_ & 0x00000002u) != 0; }
 void OperatorDef::set_has_type() { has_bits_ |= 0x00000002u; }
-int OperatorDef::mem_id() const { return mem_id_; }
-void OperatorDef::set_mem_id(const int mem_id) {
-  set_has_mem_id();
-  mem_id_ = mem_id;
-}
-bool OperatorDef::has_mem_id() const { return (has_bits_ & 0x00000004u) != 0; }
-void OperatorDef::set_has_mem_id() { has_bits_ |= 0x00000004u; }
+const std::vector<int> &OperatorDef::mem_id() const { return mem_id_; }
+void OperatorDef::set_mem_id(const std::vector<int> &value) {
+  mem_id_.resize(value.size());
+  std::copy(value.begin(), value.end(), mem_id_.begin());
+}
 uint32_t OperatorDef::node_id() const { return node_id_; }
 void OperatorDef::set_node_id(uint32_t node_id) { node_id_ = node_id; }
 uint32_t OperatorDef::op_id() const { return op_id_; }
......
@@ -83,6 +83,7 @@ extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
 extern void Register_Reshape(OperatorRegistry *op_registry);
 extern void Register_Eltwise(OperatorRegistry *op_registry);
 extern void Register_FullyConnected(OperatorRegistry *op_registry);
+extern void Register_Slice(OperatorRegistry *op_registry);
 OperatorRegistry::OperatorRegistry() {
   Register_Activation(this);
@@ -109,6 +110,7 @@ OperatorRegistry::OperatorRegistry() {
   Register_Reshape(this);
   Register_Eltwise(this);
   Register_FullyConnected(this);
+  Register_Slice(this);
 }
 }  // namespace mace
@@ -116,7 +116,7 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
   // As DSP may have different data output type for each op,
   // we stick to the same concept.
   for (auto &op : net_def.op()) {
-    if (op.has_mem_id()) {
+    if (!op.mem_id().empty()) {
       const DataType op_dtype = static_cast<DataType>(
           ArgumentHelper::GetSingleArgument<OperatorDef, int>(
               op, "T", static_cast<int>(DT_FLOAT)));
@@ -135,18 +135,20 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
   }
   VLOG(3) << "Preallocate image to tensors";
   for (auto &op : net_def.op()) {
-    if (op.has_mem_id()) {
-      std::unique_ptr<Tensor> tensor(
-          new Tensor(preallocated_allocator_.GetBuffer(op.mem_id()), dtype));
-      tensor->SetSourceOpName(op.name());
-      VLOG(3)
-          << "Tensor: " << op.name() << "(" << op.type() << ")"
-          << "; Mem: " << op.mem_id() << "; Image shape: "
-          << dynamic_cast<Image *>(tensor->UnderlyingBuffer())->image_shape()[0]
-          << ", "
-          << dynamic_cast<Image *>(tensor->UnderlyingBuffer())
-                 ->image_shape()[1];
-      tensor_map_[op.output(0)] = std::move(tensor);
+    if (!op.mem_id().empty()) {
+      auto mem_ids = op.mem_id();
+      int count = mem_ids.size();
+      for (int i = 0; i < count; ++i) {
+        std::unique_ptr<Tensor> tensor(
+            new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]), dtype));
+        tensor->SetSourceOpName(op.name());
+        VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")" << "; Mem: "
+                << mem_ids[i] << "; Image shape: "
+                << dynamic_cast<Image *>(tensor->UnderlyingBuffer())->image_shape()[0]
+                << ", "
+                << dynamic_cast<Image *>(tensor->UnderlyingBuffer())->image_shape()[1];
+        tensor_map_[op.output(i)] = std::move(tensor);
+      }
     }
   }
 }
......
#include <common.h>

__kernel void slice(__read_only image2d_t input,
                    __private const int chan_blk_offset,
                    __write_only image2d_t output) {
  const int chan_blk_idx = get_global_id(0);
  const int width_idx = get_global_id(1);
  const int width = get_global_size(1);
  const int hb_idx = get_global_id(2);
  DATA_TYPE4 data = READ_IMAGET(input, SAMPLER,
                                (int2)(mad24(chan_blk_idx + chan_blk_offset,
                                             width, width_idx),
                                       hb_idx));
  WRITE_IMAGET(output,
               (int2)(mad24(chan_blk_idx, width, width_idx), hb_idx),
               data);
}
@@ -72,8 +72,6 @@ static void ConcatN(cl::Kernel *kernel,
   const index_t width = output->dim(2);
   const index_t channel = output->dim(3);
-  const int channel_blk = RoundUpDiv4(channel);
   if (kernel->get() == nullptr) {
     auto runtime = OpenCLRuntime::Global();
     std::set<std::string> built_options;
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/kernels/slice.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"

namespace mace {
namespace kernels {

template <typename T>
void SliceFunctor<DeviceType::OPENCL, T>::operator()(
    const Tensor *input,
    const std::vector<Tensor *> &output_list,
    StatsFuture *future) {
  const index_t input_channels = input->dim(3);
  const size_t outputs_count = output_list.size();
  const index_t output_channels = input_channels / outputs_count;
  MACE_CHECK(output_channels % 4 == 0)
      << "output channels of slice op must be divisible by 4";

  std::vector<index_t> output_shape({input->dim(0), input->dim(1),
                                     input->dim(2), output_channels});
  std::vector<size_t> image_shape;
  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, image_shape);
  for (size_t i = 0; i < outputs_count; ++i) {
    output_list[i]->ResizeImage(output_shape, image_shape);
  }

  if (kernel_.get() == nullptr) {
    auto runtime = OpenCLRuntime::Global();
    std::set<std::string> built_options;
    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("slice");
    built_options.emplace("-Dslice=" + kernel_name);
    built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
    built_options.emplace("-DCMD_DATA_TYPE="
                          + DtToCLCMDDt(DataTypeToEnum<T>::value));
    kernel_ = runtime->BuildKernel("slice", kernel_name, built_options);
  }

  const index_t channel_blk = RoundUpDiv4(output_channels);
  const uint32_t gws[3] = {
      static_cast<uint32_t>(channel_blk),
      static_cast<uint32_t>(input->dim(2)),
      static_cast<uint32_t>(input->dim(0) * input->dim(1)),
  };
  const std::vector<uint32_t> lws = {8, 16, 8, 1};

  std::stringstream ss;
  ss << "slice_opencl_kernel_"
     << input->dim(0) << "_"
     << input->dim(1) << "_"
     << input->dim(2) << "_"
     << input_channels << "_"
     << outputs_count;
  for (int i = 0; i < outputs_count; ++i) {
    uint32_t idx = 0;
    kernel_.setArg(idx++, *(input->opencl_image()));
    kernel_.setArg(idx++, static_cast<int32_t>(channel_blk * i));
    kernel_.setArg(idx++, *(output_list[i]->opencl_image()));
    TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
  }
}

template struct SliceFunctor<DeviceType::OPENCL, float>;
template struct SliceFunctor<DeviceType::OPENCL, half>;

}  // namespace kernels
}  // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_KERNELS_SLICE_H_
#define MACE_KERNELS_SLICE_H_

#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#include "mace/core/types.h"
#include "mace/public/mace.h"

namespace mace {
namespace kernels {

template <DeviceType D, typename T>
struct SliceFunctor {
  void operator()(const Tensor *input,
                  const std::vector<Tensor *> &output_list,
                  StatsFuture *future) {
    const index_t outer_size = input->dim(0) * input->dim(1) * input->dim(2);
    const index_t input_channels = input->dim(3);
    const size_t outputs_count = output_list.size();
    const index_t output_channels = input_channels / outputs_count;

    std::vector<T *> output_ptrs(output_list.size(), nullptr);
    std::vector<index_t> output_shape({input->dim(0), input->dim(1),
                                       input->dim(2), output_channels});
    for (size_t i = 0; i < outputs_count; ++i) {
      output_list[i]->Resize(output_shape);
      output_ptrs[i] = output_list[i]->mutable_data<T>();
    }
    const T *input_ptr = input->data<T>();

#pragma omp parallel for
    for (int outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
      int input_idx = outer_idx * input_channels;
      int output_idx = outer_idx * output_channels;
      for (size_t i = 0; i < outputs_count; ++i) {
        if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
          memcpy(output_ptrs[i] + output_idx, input_ptr + input_idx,
                 output_channels * sizeof(T));
        } else {
          for (index_t k = 0; k < output_channels; ++k) {
            *(output_ptrs[i] + output_idx + k) = *(input_ptr + input_idx + k);
          }
        }
        input_idx += output_channels;
      }
    }
  }
};

template <typename T>
struct SliceFunctor<DeviceType::OPENCL, T> {
  void operator()(const Tensor *input,
                  const std::vector<Tensor *> &output_list,
                  StatsFuture *future);
  cl::Kernel kernel_;
};

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_SLICE_H_
@@ -12,11 +12,6 @@ void Register_Concat(OperatorRegistry *op_registry) {
                         .TypeConstraint<float>("T")
                         .Build(),
                     ConcatOp<DeviceType::CPU, float>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
-                        .Device(DeviceType::CPU)
-                        .TypeConstraint<half>("T")
-                        .Build(),
-                    ConcatOp<DeviceType::CPU, half>);
   REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
                         .Device(DeviceType::OPENCL)
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/ops/slice.h"

namespace mace {

void Register_Slice(OperatorRegistry *op_registry) {
  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
                        .Device(DeviceType::CPU)
                        .TypeConstraint<float>("T")
                        .Build(),
                    SliceOp<DeviceType::CPU, float>);
  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
                        .Device(DeviceType::OPENCL)
                        .TypeConstraint<float>("T")
                        .Build(),
                    SliceOp<DeviceType::OPENCL, float>);
  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
                        .Device(DeviceType::OPENCL)
                        .TypeConstraint<half>("T")
                        .Build(),
                    SliceOp<DeviceType::OPENCL, half>);
}

}  // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_OPS_SLICE_H_
#define MACE_OPS_SLICE_H_

#include "mace/core/operator.h"
#include "mace/kernels/slice.h"

namespace mace {

template <DeviceType D, typename T>
class SliceOp : public Operator<D, T> {
 public:
  SliceOp(const OperatorDef &op_def, Workspace *ws)
      : Operator<D, T>(op_def, ws) {}

  bool Run(StatsFuture *future) override {
    MACE_CHECK(this->OutputSize() >= 2)
        << "There must be at least two outputs for slicing";
    const Tensor *input = this->Input(INPUT);
    const std::vector<Tensor *> output_list = this->Outputs();
    MACE_CHECK((input->dim(3) % this->OutputSize()) == 0)
        << "Outputs do not split input equally.";
    functor_(input, output_list, future);
    return true;
  }

 private:
  kernels::SliceFunctor<D, T> functor_;

 private:
  OP_INPUT_TAGS(INPUT);
};

}  // namespace mace

#endif  // MACE_OPS_SLICE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"

namespace mace {

template <DeviceType D, typename T>
static void BMSliceHelper(int iters,
                          const std::vector<index_t> &input_shape,
                          const index_t num_outputs) {
  mace::testing::StopTiming();

  // Construct graph
  OpsTestNet net;
  const index_t input_size = std::accumulate(
      input_shape.begin(), input_shape.end(), 1, std::multiplies<index_t>());
  std::vector<float> input_data(input_size);
  GenerateRandomRealTypeData(input_shape, input_data);
  net.AddInputFromArray<D, float>("Input", input_shape, input_data);

  if (D == DeviceType::OPENCL) {
    BufferToImage<D, T>(net, "Input", "InputImage",
                        kernels::BufferType::IN_OUT_CHANNEL);
    auto builder = OpDefBuilder("Slice", "SliceTest");
    builder.Input("InputImage");
    for (int i = 0; i < num_outputs; ++i) {
      builder = builder.Output(MakeString("OutputImage", i));
    }
    builder
        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
        .Finalize(net.NewOperatorDef());
  } else {
    auto builder = OpDefBuilder("Slice", "SliceTest");
    builder.Input("Input");
    for (int i = 0; i < num_outputs; ++i) {
      builder = builder.Output(MakeString("Output", i));
    }
    builder.Finalize(net.NewOperatorDef());
  }

  // Warm-up
  for (int i = 0; i < 2; ++i) {
    net.RunOp(D);
    net.Sync();
  }

  mace::testing::StartTiming();
  while (iters--) {
    net.RunOp(D);
    net.Sync();
  }
}

#define BM_SLICE_MACRO(N, H, W, C, NO, TYPE, DEVICE)                        \
  static void BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE(    \
      int iters) {                                                          \
    const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C;        \
    mace::testing::MaccProcessed(tot);                                      \
    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                     \
    BMSliceHelper<DEVICE, TYPE>(iters, {N, H, W, C}, NO);                   \
  }                                                                         \
  BENCHMARK(BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE)

#define BM_SLICE(N, H, W, C, NO)                 \
  BM_SLICE_MACRO(N, H, W, C, NO, float, CPU);    \
  BM_SLICE_MACRO(N, H, W, C, NO, float, OPENCL); \
  BM_SLICE_MACRO(N, H, W, C, NO, half, OPENCL);

BM_SLICE(1, 32, 32, 32, 2);
BM_SLICE(1, 32, 32, 128, 2);
BM_SLICE(1, 32, 32, 256, 2);
BM_SLICE(1, 128, 128, 32, 2);
BM_SLICE(1, 128, 128, 128, 2);

}  // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/ops/slice.h"
#include "mace/ops/ops_test_util.h"
#include "gmock/gmock.h"

using namespace mace;

class SliceOpTest : public OpsTestBase {};

template <DeviceType D, typename T>
void RandomTest(const int num_outputs) {
  srand(time(nullptr));
  const index_t output_channels = 4 * (1 + rand() % 10);
  const index_t input_channels = num_outputs * output_channels;
  const index_t batch = 3 + (rand() % 10);
  const index_t height = 13 + (rand() % 10);
  const index_t width = 17 + (rand() % 10);

  // Construct graph
  OpsTestNet net;
  std::vector<index_t> input_shape({batch, height, width, input_channels});
  const index_t input_size = std::accumulate(
      input_shape.begin(), input_shape.end(), 1, std::multiplies<index_t>());
  std::vector<float> input_data(input_size);
  GenerateRandomRealTypeData(input_shape, input_data);
  net.AddInputFromArray<D, float>("Input", input_shape, input_data);

  if (D == DeviceType::OPENCL) {
    BufferToImage<D, T>(net, "Input", "InputImage",
                        kernels::BufferType::IN_OUT_CHANNEL);
    auto builder = OpDefBuilder("Slice", "SliceTest");
    builder.Input("InputImage");
    for (int i = 0; i < num_outputs; ++i) {
      builder = builder.Output(MakeString("OutputImage", i));
    }
    builder
        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
        .Finalize(net.NewOperatorDef());
  } else {
    auto builder = OpDefBuilder("Slice", "SliceTest");
    builder.Input("Input");
    for (int i = 0; i < num_outputs; ++i) {
      builder = builder.Output(MakeString("Output", i));
    }
    builder.Finalize(net.NewOperatorDef());
  }

  // Run
  net.RunOp(D);

  if (D == DeviceType::OPENCL) {
    for (int i = 0; i < num_outputs; ++i) {
      ImageToBuffer<D, float>(net, MakeString("OutputImage", i),
                              MakeString("Output", i),
                              kernels::BufferType::IN_OUT_CHANNEL);
    }
  }

  // Check
  std::vector<index_t> expected_shape({batch, height, width, output_channels});
  const index_t outer_size = std::accumulate(
      expected_shape.begin(), expected_shape.end() - 1, 1,
      std::multiplies<index_t>());
  const float *input_ptr = input_data.data();
  const float *output_ptr;
  for (int i = 0; i < num_outputs; ++i) {
    auto output = net.GetOutput(MakeString("Output", i).c_str());
    EXPECT_THAT(output->shape(), ::testing::ContainerEq(expected_shape));
    Tensor::MappingGuard output_mapper(output);
    output_ptr = output->data<float>();
    for (int outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
      const int idx = outer_idx * input_channels + i * output_channels;
      for (int j = 0; j < output_channels; ++j) {
        ASSERT_NEAR(*output_ptr++, input_ptr[idx + j], 1e-2)
            << "with output " << i << " index " << idx + j;
      }
    }
  }
}

TEST_F(SliceOpTest, CPU) {
  RandomTest<DeviceType::CPU, float>(2);
  RandomTest<DeviceType::CPU, float>(4);
  RandomTest<DeviceType::CPU, float>(11);
}

TEST_F(SliceOpTest, OPENCLFloat) {
  RandomTest<DeviceType::OPENCL, float>(2);
  RandomTest<DeviceType::OPENCL, float>(4);
  RandomTest<DeviceType::OPENCL, float>(11);
}

TEST_F(SliceOpTest, OPENCLHalf) {
  RandomTest<DeviceType::OPENCL, half>(2);
  RandomTest<DeviceType::OPENCL, half>(4);
  RandomTest<DeviceType::OPENCL, half>(11);
}
@@ -174,9 +174,8 @@ class OperatorDef {
   const std::string &type() const;
   void set_type(const std::string &type_);
   bool has_type() const;
-  int mem_id() const;
-  void set_mem_id(const int mem_id);
-  bool has_mem_id() const;
+  const std::vector<int> &mem_id() const;
+  void set_mem_id(const std::vector<int> &value);
   uint32_t node_id() const;
   void set_node_id(uint32_t node_id);
   uint32_t op_id() const;
@@ -220,7 +219,7 @@ class OperatorDef {
   std::vector<OutputShape> output_shape_;
   std::vector<DataType> output_type_;
-  int mem_id_;
+  std::vector<int> mem_id_;
   // nnlib
   uint32_t node_id_;
......