diff --git a/mace/core/operator.cc b/mace/core/operator.cc index fcf9069de8f2c21adb3589d4b66280464b068d90..dcbd9f7be8c2e252f0e16003a185cff2888fe499 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -83,7 +83,6 @@ extern void Register_FusedConv2D(OperatorRegistry *op_registry); extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry); extern void Register_ImageToBuffer(OperatorRegistry *op_registry); extern void Register_MatMul(OperatorRegistry *op_registry); -extern void Register_Neg(OperatorRegistry *op_registry); extern void Register_Pooling(OperatorRegistry *op_registry); extern void Register_Proposal(OperatorRegistry *op_registry); extern void Register_PSROIAlign(OperatorRegistry *op_registry); @@ -122,7 +121,6 @@ OperatorRegistry::OperatorRegistry() { ops::Register_GlobalAvgPooling(this); ops::Register_ImageToBuffer(this); ops::Register_MatMul(this); - ops::Register_Neg(this); ops::Register_Pooling(this); ops::Register_Proposal(this); ops::Register_PSROIAlign(this); diff --git a/mace/kernels/negative.h b/mace/kernels/negative.h deleted file mode 100644 index 544a1854c0d5c05a181fab61e81481cbab103690..0000000000000000000000000000000000000000 --- a/mace/kernels/negative.h +++ /dev/null @@ -1,69 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#ifndef MACE_KERNELS_NEGATIVE_H_ -#define MACE_KERNELS_NEGATIVE_H_ - -#include - -#include "mace/core/future.h" -#include "mace/core/runtime/opencl/cl2_header.h" -#include "mace/core/tensor.h" -#include "mace/public/mace.h" - -namespace mace { -namespace kernels { - -template -struct NegFunctor { - void operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { - const index_t batch = input->dim(0); - const index_t height = input->dim(1); - const index_t width = input->dim(2); - const index_t channels = input->dim(3); - - Tensor::MappingGuard input_mapper(input); - Tensor::MappingGuard output_mapper(output); - - const T *input_ptr = input->data(); - T *output_ptr = output->mutable_data(); - -#pragma omp parallel for collapse(4) - for (index_t n = 0; n < batch; ++n) { - for (index_t h = 0; h < height; ++h) { - for (index_t w = 0; w < width; ++w) { - for (index_t c = 0; c < channels; ++c) { - index_t pos = (((n * height) + h) * width + w) * channels + c; - output_ptr[pos] = 0 - input_ptr[pos]; - } - } - } - } - } -}; - -/* -template <> -void NegFunctor::operator()( - const Tensor *input, - const Tensor *bias, - Tensor *output, - StatsFuture *future); -*/ - -template -struct NegFunctor { - void operator()(const Tensor *input, - Tensor *output, - StatsFuture *future); - cl::Kernel kernel_; - std::vector input_shape_; -}; - -} // namespace kernels -} // namespace mace - -#endif // MACE_KERNELS_NEGATIVE_H_ diff --git a/mace/kernels/opencl/cl/neg.cl b/mace/kernels/opencl/cl/neg.cl deleted file mode 100644 index 7b539ddaf50cc0e09f07c7dde5f7ea9979158ea8..0000000000000000000000000000000000000000 --- a/mace/kernels/opencl/cl/neg.cl +++ /dev/null @@ -1,14 +0,0 @@ -#include -// Supported data types: half/float -__kernel void neg(__read_only image2d_t input, - __write_only image2d_t output) { - const int ch_blk = get_global_id(0); - const int w = get_global_id(1); - const int hb = get_global_id(2); - const int width = get_global_size(1); - - const int pos = mad24(ch_blk, width, w); - DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb)); - DATA_TYPE4 out = 0 - in; - WRITE_IMAGET(output, (int2)(pos, hb), out); -} diff --git a/mace/kernels/opencl/negative_opencl.cc b/mace/kernels/opencl/negative_opencl.cc deleted file mode 100644 index 70f866d89dc754e4e248d82a599d4c9bc9f8d1b6..0000000000000000000000000000000000000000 --- a/mace/kernels/opencl/negative_opencl.cc +++ /dev/null @@ -1,65 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/kernels/negative.h" -#include "mace/core/runtime/opencl/cl2_header.h" -#include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/kernels/opencl/helper.h" -#include "mace/utils/utils.h" - -namespace mace { -namespace kernels { - -template -void NegFunctor::operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { - const index_t batch = input->dim(0); - const index_t height = input->dim(1); - const index_t width = input->dim(2); - const index_t channels = input->dim(3); - - const index_t channel_blocks = RoundUpDiv4(channels); - - auto runtime = OpenCLRuntime::Global(); - if (kernel_.get() == nullptr) { - std::set built_options; - auto dt = DataTypeToEnum::value; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("neg"); - built_options.emplace("-Dneg=" + kernel_name); - built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); - built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); - kernel_ = runtime->BuildKernel("neg", kernel_name, built_options); - } - if (!IsVecEqual(input_shape_, input->shape())) { - uint32_t idx = 0; - kernel_.setArg(idx++, *(input->opencl_image())); - kernel_.setArg(idx++, *(output->opencl_image())); - input_shape_ = input->shape(); - } - - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width), - static_cast(height * batch)}; - const std::vector lws = {8, 16, 8}; - - cl::Event event; - cl_int error = runtime->command_queue().enqueueNDRangeKernel( - kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]), - cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event); - MACE_CHECK(error == CL_SUCCESS); - if (future != nullptr) { - future->wait_fn = [runtime, event](CallStats *stats) { - event.wait(); - if (stats != nullptr) { - runtime->GetCallStats(event, stats); - } - }; - } -} - -template struct NegFunctor; -template struct NegFunctor; -} // namespace kernels -} // namespace mace diff --git a/mace/ops/neg.cc b/mace/ops/neg.cc deleted file mode 100644 index c4dffe708c3be7fa3e0d0f77278b977bfba01bd0..0000000000000000000000000000000000000000 --- a/mace/ops/neg.cc +++ /dev/null @@ -1,31 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/ops/neg.h" - -namespace mace { -namespace ops { - -void Register_Neg(OperatorRegistry *op_registry) { - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg") - .Device(DeviceType::CPU) - .TypeConstraint("T") - .Build(), - NegOp); - - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg") - .Device(DeviceType::OPENCL) - .TypeConstraint("T") - .Build(), - NegOp); - - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg") - .Device(DeviceType::OPENCL) - .TypeConstraint("T") - .Build(), - NegOp); -} - -} // namespace ops -} // namespace mace diff --git a/mace/ops/neg.h b/mace/ops/neg.h deleted file mode 100644 index 0e3be04cb0b73b803d32b71deb64857f737d8485..0000000000000000000000000000000000000000 --- a/mace/ops/neg.h +++ /dev/null @@ -1,39 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#ifndef MACE_OPS_NEG_H_ -#define MACE_OPS_NEG_H_ - -#include - -#include "mace/core/operator.h" -#include "mace/kernels/negative.h" - -namespace mace { -namespace ops { - -template -class NegOp : public Operator { - public: - NegOp(const OperatorDef &operator_def, Workspace *ws) - : Operator(operator_def, ws), - functor_() {} - - bool Run(StatsFuture *future) override { - const Tensor *input_tensor = this->Input(0); - Tensor *output_tensor = this->outputs_[0]; - output_tensor->ResizeLike(input_tensor); - - functor_(input_tensor, output_tensor, future); - return true; - } - - private: - kernels::NegFunctor functor_; -}; - -} // namespace ops -} // namespace mace - -#endif // MACE_OPS_NEGATIVE_H_ diff --git a/mace/ops/neg_benchmark.cc b/mace/ops/neg_benchmark.cc deleted file mode 100644 index 4d8f6cce16bf2f6ee9ee98651d6ecda3908f05a0..0000000000000000000000000000000000000000 --- a/mace/ops/neg_benchmark.cc +++ /dev/null @@ -1,82 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/core/operator.h" -#include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/core/testing/test_benchmark.h" -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -template -static void Neg(int iters, int batch, int channels, int height, int width) { - mace::testing::StopTiming(); - - OpsTestNet net; - - // Add input data - net.AddRandomInput("Input", {batch, height, width, channels}); - - if (D == DeviceType::OPENCL) { - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); - - OpDefBuilder("Neg", "NegBM") - .Input("InputImage") - .Output("Output") - .Finalize(net.NewOperatorDef()); - } else { - OpDefBuilder("Neg", "NegBM") - .Input("Input") - .Output("Output") - .Finalize(net.NewOperatorDef()); - } - - // Warm-up - for (int i = 0; i < 5; ++i) { - net.RunOp(D); - } - net.Sync(); - - mace::testing::StartTiming(); - while (iters--) { - net.RunOp(D); - } - net.Sync(); -} - -#define BM_NEG_MACRO(N, C, H, W, TYPE, DEVICE) \ - static void BM_NEG_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \ - int iters) { \ - const int64_t tot = static_cast(iters) * N * C * H * W; \ - mace::testing::MaccProcessed(tot); \ - mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - Neg(iters, N, C, H, W); \ - } \ - BENCHMARK(BM_NEG_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) - -#define BM_NEG(N, C, H, W) \ - BM_NEG_MACRO(N, C, H, W, float, CPU); \ - BM_NEG_MACRO(N, C, H, W, float, OPENCL); \ - BM_NEG_MACRO(N, C, H, W, half, OPENCL); - -BM_NEG(1, 1, 512, 512); -BM_NEG(1, 3, 128, 128); -BM_NEG(1, 3, 512, 512); -BM_NEG(1, 32, 112, 112); -BM_NEG(1, 64, 256, 256); -BM_NEG(1, 64, 512, 512); -BM_NEG(1, 128, 56, 56); -BM_NEG(1, 128, 256, 256); -BM_NEG(1, 256, 14, 14); -BM_NEG(1, 512, 14, 14); -BM_NEG(1, 1024, 7, 7); -BM_NEG(32, 1, 256, 256); -BM_NEG(32, 3, 256, 256); - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/ops/neg_test.cc b/mace/ops/neg_test.cc deleted file mode 100644 index c7ae15f65426f366f790013b7eed430a1df9bfea..0000000000000000000000000000000000000000 --- a/mace/ops/neg_test.cc +++ /dev/null @@ -1,61 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/core/operator.h" -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -class NegOpTest : public OpsTestBase {}; - -template -void NegSimple() { - OpsTestNet net; - - // Add input data - net.AddInputFromArray("Input", {1, 6, 2, 1}, - {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); - - if (D == DeviceType::OPENCL) { - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); - - OpDefBuilder("Neg", "NegTest") - .Input("InputImage") - .Output("OutputImage") - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(D); - - // Transfer output - ImageToBuffer(&net, "OutputImage", "Output", - kernels::BufferType::IN_OUT_CHANNEL); - } else { - OpDefBuilder("Neg", "NegTest") - .Input("Input") - .Output("Output") - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(D); - } - - // Check - auto expected = CreateTensor( - {1, 6, 2, 1}, - {-5, -5, -7, -7, -9, -9, -11, -11, -13, -13, -15, -15}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-8); -} - -TEST_F(NegOpTest, NegSimpleCPU) { NegSimple(); } - -TEST_F(NegOpTest, NegSimpleOPENCL) { - NegSimple(); -} - -} // namespace test -} // namespace ops -} // namespace mace