Commit ef1837f6 authored by: liuqi

Finish softmax op and validate new GCN model (with softmax)

Parent e5d3bb53
...@@ -75,6 +75,7 @@ extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Relu(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
OperatorRegistry::OperatorRegistry() {
Register_AddN(this);
...@@ -93,6 +94,7 @@ OperatorRegistry::OperatorRegistry() {
Register_Relu(this);
Register_ResizeBilinear(this);
Register_SpaceToBatchND(this);
Register_Softmax(this);
}
}  // namespace mace
#include <common.h>
__kernel void softmax(__read_only image2d_t input,
__private const int channels,
__private const int remain_channels,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
const int chan_blks = get_global_size(0) - 1;
const int width = get_global_size(1);
int pos = width_idx;
DATA_TYPE max_value = -FLT_MAX;
DATA_TYPE sum = 0.0;
DATA_TYPE4 data;
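// Pass 1: running max over the full channel blocks of this spatial position;
// blocks are laid out along the image x-axis at a stride of `width`.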
for (short i = 0; i < chan_blks; ++i) {
data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
max_value = max(max_value, data.x);
max_value = max(max_value, data.y);
max_value = max(max_value, data.z);
max_value = max(max_value, data.w);
pos += width;
}
data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
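// Fold in the last (possibly padded) channel block; the intentional switch
// fall-through skips the `remain_channels` padded lanes, counted from .w down.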
switch(remain_channels) {
case 0:
max_value = max(max_value, data.w);
case 1:
max_value = max(max_value, data.z);
case 2:
max_value = max(max_value, data.y);
case 3:
max_value = max(max_value, data.x);
}
pos = width_idx;
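// Pass 2: accumulate exp(x - max_value) over the full channel blocks.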
for (short i = 0; i < chan_blks; ++i) {
data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
data = native_exp(data - max_value);
sum += data.x;
sum += data.y;
sum += data.z;
sum += data.w;
pos += width;
}
data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
data -= max_value;
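// Same fall-through handling for the valid lanes of the last block.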
switch(remain_channels) {
case 0:
sum += native_exp(data.w);
case 1:
sum += native_exp(data.z);
case 2:
sum += native_exp(data.y);
case 3:
sum += native_exp(data.x);
}
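// Pass 3: re-read this work-item's own channel block, normalize, and write it out.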
pos = mad24(chan_blk_idx, width, width_idx);
data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
data -= max_value;
const int exceeded = mul24(chan_blk_idx, 4) - channels;
switch(exceeded) {
case 1:
data.z = native_exp(data.z) / sum;
case 2:
data.y = native_exp(data.y) / sum;
case 3:
data.x = native_exp(data.x) / sum;
break;
default:
data = native_exp(data) / sum;
}
WRITE_IMAGET(output, (int2)(pos, hb_idx), data);
}
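For reference, the kernel above is the image-layout version of the standard three-pass softmax along the channel axis: find the max, sum the shifted exponentials, then normalize. A minimal NumPy sketch of the same computation for a single spatial position (illustrative only, not part of this commit, and ignoring the 4-channel image packing):

import numpy as np

def softmax_1d(logits):
    # Pass 1: max for numerical stability.
    max_value = np.max(logits)
    # Pass 2: shifted exponentials and their sum.
    exps = np.exp(logits - max_value)
    # Pass 3: normalize.
    return exps / np.sum(exps)

print(softmax_1d(np.array([1.0, 2.0, 3.0, 4.0])))  # ~[0.0321, 0.0871, 0.2369, 0.6439]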
...@@ -35,15 +35,15 @@ void ReluFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
relu_kernel = runtime->BuildKernel("relu", "relu", built_options);
uint32_t idx = 0;
- relu_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
+ relu_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
- relu_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
+ relu_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
} else {
relu_kernel = runtime->BuildKernel("relu", "relux", built_options);
uint32_t idx = 0;
- relu_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
+ relu_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
relu_kernel.setArg(idx++, max_limit_);
- relu_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
+ relu_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
}
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
...
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/softmax.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template<typename T>
void SoftmaxFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *logits,
Tensor *output,
StatsFuture *future) {
const index_t batch = logits->dim(0);
const index_t height = logits->dim(1);
const index_t width = logits->dim(2);
const index_t channels = logits->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const int remain_channels = channel_blocks * 4 - channels;
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
cl::Kernel softmax_kernel = runtime->BuildKernel("softmax", "softmax", built_options);
uint32_t idx = 0;
softmax_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(logits->buffer())));
softmax_kernel.setArg(idx++, static_cast<int>(channels));
softmax_kernel.setArg(idx++, remain_channels);
softmax_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(softmax_kernel);
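// Maximum work-group size the device allows for this kernel; used when
// generating the candidate local work-group sizes below.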
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(3, 0);
local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
local_ws[1] = std::min<uint32_t>(width, kwg_size / local_ws[0]);
local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
return {{4, 15, 8}, //SNPE size
{local_ws[0], local_ws[1], local_ws[2]},
{kwg_size / 16, 4, 4},
{kwg_size / 32, 4, 8},
{kwg_size / 32, 8, 4},
{kwg_size / 64, 8, 8},
{kwg_size / 64, 16, 4},
{kwg_size / 128, 8, 16},
{kwg_size / 128, 16, 8},
{kwg_size / 128, 32, 4},
{1, kwg_size / 32, 32},
{1, kwg_size / 64, 64},
{1, kwg_size / 128, 128},
{3, 15, 9},
{7, 15, 9},
{9, 7, 15},
{15, 7, 9},
{1, kwg_size, 1}};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
softmax_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << "softmax_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
}
template
struct SoftmaxFunctor<DeviceType::OPENCL, float>;
template
struct SoftmaxFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
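The launch geometry follows directly from the NHWC shape: channels are packed four per image pixel, so the global work size is one work-item per channel block, image column, and batch-height row. A small sketch of the arithmetic (illustrative only), using the 1x113x107x13 shape from the OPENCLUnAligned test later in this commit:

def round_up_div4(x):
    return (x + 3) // 4

batch, height, width, channels = 1, 113, 107, 13
channel_blocks = round_up_div4(channels)          # 4
remain_channels = channel_blocks * 4 - channels   # 3 padded lanes in the last block
gws = (channel_blocks, width, height * batch)     # (4, 107, 113)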
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_SOFTMAX_H_
#define MACE_KERNELS_SOFTMAX_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/public/mace.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct SoftmaxFunctor {
void operator()(const Tensor *logits,
Tensor *output,
StatsFuture *future) {
Tensor::MappingGuard logits_guard(logits);
Tensor::MappingGuard output_guard(output);
const T *logits_ptr = logits->data<T>();
T *output_ptr = output->mutable_data<T>();
auto &logits_shape = logits->shape();
const index_t batch_size = std::accumulate(logits_shape.begin(), logits_shape.end()-1,
1, std::multiplies<index_t>());
const index_t num_classes = logits_shape.back();
#pragma omp parallel for
for (index_t i = 0; i < batch_size; ++i) {
T max_value = *logits_ptr;
for (index_t c = 1; c < num_classes; ++c) {
max_value = std::max(max_value, logits_ptr[c]);
}
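// Subtracting the per-row max before exp() keeps the exponentials in a numerically safe range.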
// TODO: check overflow?
T sum = 0;
std::vector<T> exp_data(num_classes);
for (index_t c = 0; c < num_classes; ++c) {
exp_data[c] = ::exp((*logits_ptr - max_value));
sum += exp_data[c];
logits_ptr++;
}
for (index_t c = 0; c < num_classes; ++c) {
*output_ptr = exp_data[c] / sum;
output_ptr++;
}
}
}
};
template<typename T>
struct SoftmaxFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *logits,
Tensor *output,
StatsFuture *future);
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_SOFTMAX_H_
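As a quick sanity check of the reference implementation above (and of the expected values used by the Simple test later in this commit): for logits [1, 2, 3, 4] the max is 4, exp([-3, -2, -1, 0]) ≈ [0.0498, 0.1353, 0.3679, 1.0], the sum is ≈ 1.5530, and dividing through gives ≈ [0.0321, 0.0871, 0.2369, 0.6439].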
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/softmax.h"
namespace mace {
void Register_Softmax(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
SoftmaxOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
SoftmaxOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
SoftmaxOp<DeviceType::OPENCL, half>);
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_SOFTMAX_H_
#define MACE_SOFTMAX_H_
#include "mace/core/operator.h"
#include "mace/kernels/softmax.h"
namespace mace {
template <DeviceType D, class T>
class SoftmaxOp : public Operator<D, T> {
public:
SoftmaxOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {
}
bool Run(StatsFuture *future) override {
const Tensor *logits = this->Input(LOGITS);
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(logits);
functor_(logits, output, future);
return true;
}
private:
kernels::SoftmaxFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(LOGITS);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_SOFTMAX_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <string>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
template <DeviceType D, typename T>
static void SoftmaxBenchmark(
int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Softmax", "SoftmaxBM")
.Input("InputImage")
.Output("Output")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Softmax", "SoftmaxBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_SOFTMAX_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
SoftmaxBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_SOFTMAX(N, C, H, W, TYPE) \
BM_SOFTMAX_MACRO(N, C, H, W, TYPE, CPU); \
BM_SOFTMAX_MACRO(N, C, H, W, TYPE, OPENCL);
BM_SOFTMAX(1, 1, 512, 512, float);
BM_SOFTMAX(1, 3, 128, 128, float);
BM_SOFTMAX(1, 3, 512, 512, float);
BM_SOFTMAX(1, 32, 112, 112, float);
BM_SOFTMAX(1, 64, 256, 256, float);
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
class SoftmaxOpTest : public OpsTestBase {};
template <DeviceType D>
void Simple() {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", {1, 1, 2, 4}, {1, 1, 1, 1, 1, 2, 3, 4});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
auto expected = CreateTensor<float>({1, 1, 2, 4}, {0.25, 0.25, 0.25, 0.25,
0.0320586, 0.08714432, 0.23688282, 0.64391426});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-7);
}
TEST_F(SoftmaxOpTest, CPUSimple) {
Simple<DeviceType::CPU>();
}
TEST_F(SoftmaxOpTest, OPENCLSimple) {
Simple<DeviceType::OPENCL>();
}
template <DeviceType D>
void Complex(const std::vector<index_t> &logits_shape) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", logits_shape);
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run on cpu
net.RunOp();
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run on gpu
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-5);
}
TEST_F(SoftmaxOpTest, OPENCLAligned) {
Complex<DeviceType::OPENCL>({1, 256, 256, 3});
Complex<DeviceType::OPENCL>({1, 128, 128, 16});
}
TEST_F(SoftmaxOpTest, OPENCLMulBatchAligned) {
Complex<DeviceType::OPENCL>({5, 64, 64, 3});
Complex<DeviceType::OPENCL>({8, 128, 128, 8});
}
TEST_F(SoftmaxOpTest, OPENCLUnAligned) {
Complex<DeviceType::OPENCL>({1, 113, 107, 13});
Complex<DeviceType::OPENCL>({5, 211, 107, 1});
}
} // namespace mace
...@@ -43,8 +43,13 @@ class TFConverter(object):
self.dt = dt
self.device = device
self.tf_graph = {}
self.tf_parents = {}
self.resolved_ops = {}
self.unused_tensor = set()
self.ops = {}
for op in tf_ops:
self.ops[op.name] = op
for op in tf_ops:
self.resolved_ops[op.name] = 0
...@@ -53,6 +58,9 @@ class TFConverter(object):
if input_name not in self.tf_graph:
self.tf_graph[input_name] = []
self.tf_graph[input_name].append(op)
if op.name not in self.tf_parents:
self.tf_parents[op.name] = []
self.tf_parents[op.name].append(self.ops[input_name])
def add_buffer_to_image(self, input_name, input_type):
output_name = input_name[:-2] + "_b2i" + input_name[-2:]
...@@ -481,6 +489,36 @@ class TFConverter(object):
self.add_output_shape(final_op.outputs, op_def)
self.net_def.op.extend([op_def])
def is_softmax(self, op):
return op.type == 'Softmax' and \
len(self.tf_parents[op.name]) == 1 and self.tf_parents[op.name][0].type == 'Reshape' and \
len(self.tf_graph[op.name]) == 1 and self.tf_graph[op.name][0].type == 'Reshape'
def convert_softmax(self, softmax_op):
op_def = self.net_def.op.add()
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self.dt
# deal with first Reshape op
parent_reshape_op = self.tf_parents[softmax_op.name][0]
op_def.input.extend([parent_reshape_op.inputs[0].name])
self.unused_tensor.add(get_input_tensor(parent_reshape_op, 1).name)
self.resolved_ops[parent_reshape_op.name] = 1
# deal with Softmax op
op_def.name = softmax_op.name
op_def.type = softmax_op.type
self.resolved_ops[softmax_op.name] = 1
# deal with last Reshape op
reshape_op = self.tf_graph[softmax_op.name][0]
self.unused_tensor.add(get_input_tensor(reshape_op, 1).name)
op_def.output.extend([output.name for output in reshape_op.outputs])
self.add_output_shape(reshape_op.outputs, op_def)
self.resolved_ops[reshape_op.name] = 1
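is_softmax/convert_softmax fold the Reshape → Softmax → Reshape pattern that TensorFlow typically emits when softmax is applied to a tensor flattened to (pixels, classes), collapsing it into a single MACE Softmax that stays in NHWC. A hypothetical example of the rewrite (names are illustrative, not from the real graph):

# TF subgraph for per-pixel softmax on a (1, H, W, C) feature map:
#   features -> Reshape (to [H*W, C]) -> Softmax -> Reshape (back to [1, H, W, C])
# After convert_softmax, the MACE net contains a single op instead:
#   Softmax(input=features) -> output with the last Reshape's [1, H, W, C] shape
# Both Reshape shape tensors are added to unused_tensor, since they are no longer
# needed once the Reshapes are folded away.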
def convert_normal_op(self, op):
op_def = self.net_def.op.add()
arg = op_def.arg.add()
...@@ -529,12 +567,13 @@ class TFConverter(object):
self.convert_space_to_batch(op, False)
elif op.type == 'BatchToSpaceND':
self.convert_space_to_batch(op, True)
elif self.is_softmax(op):
self.convert_softmax(op)
elif op.type in ['Relu']:
self.convert_normal_op(op)
else:
raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type))
for op in self.tf_ops:
if self.resolved_ops[op.name] == 1:
continue
......
TF_INPUT_NODE=input_node
- TF_OUTPUT_NODE=GCN/br_result_x/fcn_br
+ TF_OUTPUT_NODE=softmax/Reshape_1
\ No newline at end of file \ No newline at end of file