diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 341493fcadfad04d099e183b426a0368df45d288..e759d89d6ec7d187ed50b688c84c6bd4dd8f1ddb 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -75,6 +75,7 @@ extern void Register_Pooling(OperatorRegistry *op_registry); extern void Register_Relu(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); +extern void Register_Softmax(OperatorRegistry *op_registry); OperatorRegistry::OperatorRegistry() { Register_AddN(this); @@ -93,6 +94,7 @@ OperatorRegistry::OperatorRegistry() { Register_Relu(this); Register_ResizeBilinear(this); Register_SpaceToBatchND(this); + Register_Softmax(this); } } // namespace mace diff --git a/mace/kernels/opencl/cl/softmax.cl b/mace/kernels/opencl/cl/softmax.cl new file mode 100644 index 0000000000000000000000000000000000000000..0188d7679153c18b347e7691a3b30d9e350e6ef5 --- /dev/null +++ b/mace/kernels/opencl/cl/softmax.cl @@ -0,0 +1,77 @@ +#include + +__kernel void softmax(__read_only image2d_t input, + __private const int channels, + __private const int remain_channels, + __write_only image2d_t output) { + const int chan_blk_idx = get_global_id(0); + const int width_idx = get_global_id(1); + const int hb_idx = get_global_id(2); + const int chan_blks = get_global_size(0) - 1; + const int width = get_global_size(1); + + int pos = width_idx; + DATA_TYPE max_value = -FLT_MAX; + DATA_TYPE sum = 0.0; + DATA_TYPE4 data; + for (short i = 0; i < chan_blks; ++i) { + data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx)); + max_value = max(max_value, data.x); + max_value = max(max_value, data.y); + max_value = max(max_value, data.z); + max_value = max(max_value, data.w); + pos += width; + } + data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx)); + switch(remain_channels) { + case 0: + max_value = max(max_value, data.w); + case 1: + max_value = max(max_value, data.z); + case 2: + max_value = max(max_value, data.y); + case 3: + max_value = max(max_value, data.x); + } + + pos = width_idx; + for (short i = 0; i < chan_blks; ++i) { + data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx)); + data = native_exp(data - max_value); + sum += data.x; + sum += data.y; + sum += data.z; + sum += data.w; + pos += width; + } + data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx)); + data -= max_value; + switch(remain_channels) { + case 0: + sum += native_exp(data.w); + case 1: + sum += native_exp(data.z); + case 2: + sum += native_exp(data.y); + case 3: + sum += native_exp(data.x); + } + + pos = mad24(chan_blk_idx, width, width_idx); + data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx)); + data -= max_value; + const int exceeded = mul24(chan_blk_idx, 4) - channels; + switch(exceeded) { + case 1: + data.z = native_exp(data.z) / sum; + case 2: + data.y = native_exp(data.y) / sum; + case 3: + data.x = native_exp(data.x) / sum; + break; + default: + data = native_exp(data) / sum; + } + + WRITE_IMAGET(output, (int2)(pos, hb_idx), data); +} diff --git a/mace/kernels/opencl/relu_opencl.cc b/mace/kernels/opencl/relu_opencl.cc index c9d5a7f18a03f419c0a1686dfa98b070f6ff7597..7561b1fa5a59d9b5f351f0f5c1a4403d401b8c1c 100644 --- a/mace/kernels/opencl/relu_opencl.cc +++ b/mace/kernels/opencl/relu_opencl.cc @@ -35,15 +35,15 @@ void ReluFunctor::operator()(const Tensor *input, relu_kernel = runtime->BuildKernel("relu", "relu", built_options); uint32_t idx = 0; - relu_kernel.setArg(idx++, *(static_cast(input->buffer()))); - relu_kernel.setArg(idx++, *(static_cast(output->buffer()))); + relu_kernel.setArg(idx++, *(static_cast(input->buffer()))); + relu_kernel.setArg(idx++, *(static_cast(output->buffer()))); } else { relu_kernel = runtime->BuildKernel("relu", "relux", built_options); uint32_t idx = 0; - relu_kernel.setArg(idx++, *(static_cast(input->buffer()))); + relu_kernel.setArg(idx++, *(static_cast(input->buffer()))); relu_kernel.setArg(idx++, max_limit_); - relu_kernel.setArg(idx++, *(static_cast(output->buffer()))); + relu_kernel.setArg(idx++, *(static_cast(output->buffer()))); } const uint32_t gws[3] = {static_cast(channel_blocks), static_cast(width), diff --git a/mace/kernels/opencl/softmax_opencl.cc b/mace/kernels/opencl/softmax_opencl.cc new file mode 100644 index 0000000000000000000000000000000000000000..147e53d54af4b155917b1b8bd065e7766bc324f0 --- /dev/null +++ b/mace/kernels/opencl/softmax_opencl.cc @@ -0,0 +1,106 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/softmax.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/opencl/helper.h" +#include "mace/utils/utils.h" +#include "mace/utils/tuner.h" + +namespace mace { +namespace kernels { + +template +void SoftmaxFunctor::operator()(const Tensor *logits, + Tensor *output, + StatsFuture *future) { + const index_t batch = logits->dim(0); + const index_t height = logits->dim(1); + const index_t width = logits->dim(2); + const index_t channels = logits->dim(3); + const index_t channel_blocks = RoundUpDiv4(channels); + const int remain_channels = channel_blocks * 4 - channels; + + auto runtime = OpenCLRuntime::Global(); + + std::set built_options; + auto dt = DataTypeToEnum::value; + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); + cl::Kernel softmax_kernel = runtime->BuildKernel("softmax", "softmax", built_options); + + uint32_t idx = 0; + softmax_kernel.setArg(idx++, *(static_cast(logits->buffer()))); + softmax_kernel.setArg(idx++, static_cast(channels)); + softmax_kernel.setArg(idx++, remain_channels); + softmax_kernel.setArg(idx++, *(static_cast(output->buffer()))); + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width), + static_cast(height * batch)}; + const std::vector lws = {8, 16, 8}; + const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(softmax_kernel); + auto params_generator = [&]() -> std::vector> { + std::vector local_ws(3, 0); + local_ws[0] = std::min(channel_blocks, kwg_size); + local_ws[1] = std::min(width, kwg_size / local_ws[0]); + local_ws[2] = std::min(height * batch, kwg_size / (local_ws[0] * local_ws[1])); + return {{4, 15, 8}, //SNPE size + {local_ws[0], local_ws[1], local_ws[2]}, + {kwg_size / 16, 4, 4}, + {kwg_size / 32, 4, 8}, + {kwg_size / 32, 8, 4}, + {kwg_size / 64, 8, 8}, + {kwg_size / 64, 16, 4}, + {kwg_size / 128, 8, 16}, + {kwg_size / 128, 16, 8}, + {kwg_size / 128, 32, 4}, + {1, kwg_size / 32, 32}, + {1, kwg_size / 64, 64}, + {1, kwg_size / 128, 128}, + {3, 15, 9}, + {7, 15, 9}, + {9, 7, 15}, + {15, 7, 9}, + {1, kwg_size, 1}}; + }; + cl::Event event; + auto func = [&](const std::vector ¶ms) -> cl_int { + cl_int error = runtime->command_queue().enqueueNDRangeKernel( + softmax_kernel, cl::NullRange, + cl::NDRange(gws[0], gws[1], gws[2]), + cl::NDRange(params[0], params[1], params[2]), + nullptr, &event); + + MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; + return error; + }; + std::stringstream ss; + ss << "softmax_opencl_kernel_" + << output->dim(0) << "_" + << output->dim(1) << "_" + << output->dim(2) << "_" + << output->dim(3); + OpenCLProfilingTimer timer(&event); + Tuner::Get()->template TuneOrRun(ss.str(), + lws, + params_generator, + func, + &timer); + if (future != nullptr) { + future->wait_fn = [runtime, event](CallStats *stats) { + event.wait(); + if (stats != nullptr) { + runtime->GetCallStats(event, stats); + } + }; + } +} + +template +struct SoftmaxFunctor; +template +struct SoftmaxFunctor; +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/softmax.h b/mace/kernels/softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..c686c60cec60098200e7da0db31472321dbe41ae --- /dev/null +++ b/mace/kernels/softmax.h @@ -0,0 +1,63 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_KERNELS_SOFTMAX_H_ +#define MACE_KERNELS_SOFTMAX_H_ + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/core/public/mace.h" + +namespace mace { +namespace kernels { + +template +struct SoftmaxFunctor { + void operator()(const Tensor *logits, + Tensor *output, + StatsFuture *future) { + + Tensor::MappingGuard logits_guard(logits); + Tensor::MappingGuard output_guard(output); + const T *logits_ptr = logits->data(); + T *output_ptr = output->mutable_data(); + auto &logits_shape = logits->shape(); + const index_t batch_size = std::accumulate(logits_shape.begin(), logits_shape.end()-1, + 1, std::multiplies()); + const index_t num_classes = logits_shape.back(); +#pragma omp parallel for + for (index_t i = 0; i < batch_size; ++i) { + T max_value = *logits_ptr; + for (index_t c = 1; c < num_classes; ++c) { + max_value = std::max(max_value, logits_ptr[c]); + } + // TODO: check overflow? + T sum = 0; + std::vector exp_data(num_classes); + for (index_t c = 0; c < num_classes; ++c) { + exp_data[c] = ::exp((*logits_ptr - max_value)); + sum += exp_data[c]; + logits_ptr++; + } + for (index_t c = 0; c < num_classes; ++c) { + *output_ptr = exp_data[c] / sum; + output_ptr++; + } + } + } +}; + + +template +struct SoftmaxFunctor { + + void operator()(const Tensor *logits, + Tensor *output, + StatsFuture *future); +}; + +} // namepsace kernels +} // namespace mace + +#endif // MACE_KERNELS_SOFTMAX_H_ diff --git a/mace/ops/softmax.cc b/mace/ops/softmax.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac48b3d85901c7b99ca3abb9c0d185e4e5da2349 --- /dev/null +++ b/mace/ops/softmax.cc @@ -0,0 +1,29 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/softmax.h" + +namespace mace { + +void Register_Softmax(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + SoftmaxOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + SoftmaxOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + SoftmaxOp); +} + +} // namespace mace diff --git a/mace/ops/softmax.h b/mace/ops/softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..26914c6932396b2ed28a6243df66bced6805ddf1 --- /dev/null +++ b/mace/ops/softmax.h @@ -0,0 +1,40 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_SOFTMAX_H_ +#define MACE_SOFTMAX_H_ + +#include "mace/core/operator.h" +#include "mace/kernels/softmax.h" + +namespace mace { + +template +class SoftmaxOp : public Operator { + public: + SoftmaxOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws) { + } + + bool Run(StatsFuture *future) override { + const Tensor *logits= this->Input(LOGITS); + + Tensor *output = this->Output(OUTPUT); + output->ResizeLike(logits); + + functor_(logits, output, future); + return true; + } + + private: + kernels::SoftmaxFunctor functor_; + + protected: + OP_INPUT_TAGS(LOGITS); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace mace + +#endif // MACE_SOFTMAX_H_ diff --git a/mace/ops/softmax_benchmark.cc b/mace/ops/softmax_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..030af807ca9d186551752116763c3ee7598ab9e6 --- /dev/null +++ b/mace/ops/softmax_benchmark.cc @@ -0,0 +1,67 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +template +static void SoftmaxBenchmark( + int iters, int batch, int channels, int height, int width) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", {batch, height, width, channels}); + + if (D == DeviceType::OPENCL) { + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT); + + OpDefBuilder("Softmax", "SoftmaxBM") + .Input("InputImage") + .Output("Output") + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("Softmax", "SoftmaxBM") + .Input("Input") + .Output("Output") + .Finalize(net.NewOperatorDef()); + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} + +#define BM_SOFTMAX_MACRO(N, C, H, W, TYPE, DEVICE) \ + static void BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::ItemsProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + SoftmaxBenchmark(iters, N, C, H, W); \ + } \ + BENCHMARK(BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) + +#define BM_SOFTMAX(N, C, H, W, TYPE) \ + BM_SOFTMAX_MACRO(N, C, H, W, TYPE, CPU); \ + BM_SOFTMAX_MACRO(N, C, H, W, TYPE, OPENCL); + +BM_SOFTMAX(1, 1, 512, 512, float); +BM_SOFTMAX(1, 3, 128, 128, float); +BM_SOFTMAX(1, 3, 512, 512, float); +BM_SOFTMAX(1, 32, 112, 112, float); +BM_SOFTMAX(1, 64, 256, 256, float); +} // namespace mace diff --git a/mace/ops/softmax_test.cc b/mace/ops/softmax_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4f321a6d86ec0c7b7db7633a7ad6d43ada0d916 --- /dev/null +++ b/mace/ops/softmax_test.cc @@ -0,0 +1,107 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { + +class SoftmaxOpTest : public OpsTestBase {}; + +template +void Simple() { + // Construct graph + OpsTestNet net; + // Add input data + net.AddInputFromArray("Input", {1, 1, 2, 4}, {1, 1, 1, 1, 1, 2, 3, 4}); + + if (D == DeviceType::OPENCL) { + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT); + + OpDefBuilder("Softmax", "SoftmaxTest") + .Input("InputImage") + .Output("OutputImage") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + + // Transfer output + ImageToBuffer(net, "OutputImage", "Output", + kernels::BufferType::IN_OUT); + } else { + OpDefBuilder("Softmax", "SoftmaxTest") + .Input("Input") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + } + + auto expected = CreateTensor({1, 1, 2, 4}, {0.25, 0.25, 0.25, 0.25, + 0.0320586, 0.08714432, 0.23688282, 0.64391426}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-7); +} + +TEST_F(SoftmaxOpTest, CPUSimple) { + Simple(); +} +TEST_F(SoftmaxOpTest, OPENCLSimple) { + Simple(); +} + +template +void Complex(const std::vector &logits_shape) { + // Construct graph + OpsTestNet net; + // Add input data + net.AddRandomInput("Input", logits_shape); + + OpDefBuilder("Softmax", "SoftmaxTest") + .Input("Input") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Run on cpu + net.RunOp(); + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT); + + OpDefBuilder("Softmax", "SoftmaxTest") + .Input("InputImage") + .Output("OutputImage") + .Finalize(net.NewOperatorDef()); + + // Run on gpu + net.RunOp(D); + + // Transfer output + ImageToBuffer(net, "OutputImage", "OPENCLOutput", + kernels::BufferType::IN_OUT); + + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-5); +} + +TEST_F(SoftmaxOpTest, OPENCLAligned) { + Complex({1, 256, 256, 3}); + Complex({1, 128, 128, 16}); +} + +TEST_F(SoftmaxOpTest, OPENCLMulBatchAligned) { + Complex({5, 64, 64, 3}); + Complex({8, 128, 128, 8}); +} + +TEST_F(SoftmaxOpTest, OPENCLUnAligned) { + Complex({1, 113, 107, 13}); + Complex({5, 211, 107, 1}); +} + +} // namespace mace diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 9dee8f216c2ca55092a775123c45ab7385fab5e4..006d41fe8851637a731e15e22c0b9cc36e69be13 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -43,8 +43,13 @@ class TFConverter(object): self.dt = dt self.device = device self.tf_graph = {} + self.tf_parents = {} self.resolved_ops = {} self.unused_tensor = set() + self.ops = {} + + for op in tf_ops: + self.ops[op.name] = op for op in tf_ops: self.resolved_ops[op.name] = 0 @@ -53,6 +58,9 @@ class TFConverter(object): if input_name not in self.tf_graph: self.tf_graph[input_name] = [] self.tf_graph[input_name].append(op) + if op.name not in self.tf_parents: + self.tf_parents[op.name] = [] + self.tf_parents[op.name].append(self.ops[input_name]) def add_buffer_to_image(self, input_name, input_type): output_name = input_name[:-2] + "_b2i" + input_name[-2:] @@ -481,6 +489,36 @@ class TFConverter(object): self.add_output_shape(final_op.outputs, op_def) self.net_def.op.extend([op_def]) + def is_softmax(self, op): + return op.type == 'Softmax' and \ + len(self.tf_parents[op.name]) == 1 and self.tf_parents[op.name][0].type == 'Reshape' and \ + len(self.tf_graph[op.name]) == 1 and self.tf_graph[op.name][0].type == 'Reshape' + + def convert_softmax(self, softmax_op): + op_def = self.net_def.op.add() + arg = op_def.arg.add() + arg.name = 'T' + arg.i = self.dt + + # deal with first Reshape op + parent_reshape_op = self.tf_parents[softmax_op.name][0] + op_def.input.extend([parent_reshape_op.inputs[0].name]) + self.unused_tensor.add(get_input_tensor(parent_reshape_op, 1).name) + self.resolved_ops[parent_reshape_op.name] = 1 + + # deal with Softmax op + op_def.name = softmax_op.name + op_def.type = softmax_op.type + self.resolved_ops[softmax_op.name] = 1 + + # deal with last Reshape op + reshape_op = self.tf_graph[softmax_op.name][0] + self.unused_tensor.add(get_input_tensor(reshape_op, 1).name) + + op_def.output.extend([output.name for output in reshape_op.outputs]) + self.add_output_shape(reshape_op.outputs, op_def) + self.resolved_ops[reshape_op.name] = 1 + def convert_normal_op(self, op): op_def = self.net_def.op.add() arg = op_def.arg.add() @@ -529,12 +567,13 @@ class TFConverter(object): self.convert_space_to_batch(op, False) elif op.type == 'BatchToSpaceND': self.convert_space_to_batch(op, True) + elif self.is_softmax(op): + self.convert_softmax(op) elif op.type in ['Relu']: self.convert_normal_op(op) else: raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type)) - for op in self.tf_ops: if self.resolved_ops[op.name] == 1: continue diff --git a/tools/side_gcn.config b/tools/side_gcn.config index d22d730bac70cce3f5c665b5c83c56334f1de319..c7e23e97c0b77fcdc359f347529b06409e77acf4 100644 --- a/tools/side_gcn.config +++ b/tools/side_gcn.config @@ -1,2 +1,2 @@ TF_INPUT_NODE=input_node -TF_OUTPUT_NODE=GCN/br_result_x/fcn_br \ No newline at end of file +TF_OUTPUT_NODE=softmax/Reshape_1 \ No newline at end of file