From 5085a24f85f548c263e365a922b6a4df76cb9f3f Mon Sep 17 00:00:00 2001 From: liutuo Date: Fri, 20 Apr 2018 15:33:32 +0800 Subject: [PATCH] eltwise support multi dimensions --- mace/kernels/cwise.h | 2 + mace/kernels/eltwise.h | 230 ++++++++++++++++----- mace/kernels/opencl/cl/cwise.cl | 9 +- mace/kernels/opencl/cl/eltwise.cl | 93 ++++++++- mace/kernels/opencl/eltwise_opencl.cc | 23 ++- mace/ops/eltwise.h | 51 +++-- mace/ops/eltwise_test.cc | 275 ++++++++++++++++++++++++-- mace/ops/ops_test_util.h | 23 ++- mace/python/tools/tf_converter_lib.py | 34 ++-- 9 files changed, 621 insertions(+), 119 deletions(-) diff --git a/mace/kernels/cwise.h b/mace/kernels/cwise.h index 997410a3..5006b2cb 100644 --- a/mace/kernels/cwise.h +++ b/mace/kernels/cwise.h @@ -35,6 +35,7 @@ enum CWiseType { DIV = 5, NEG = 6, ABS = 7, + SQR_DIFF = 8, }; struct CWiseFunctorBase { @@ -92,6 +93,7 @@ struct CWiseFunctor : CWiseFunctorBase { } break; case DIV: + MACE_CHECK(fabs(coeff_) > 1e-6, "cannot divided by 0."); #pragma omp parallel for for (index_t i = 0; i < size; ++i) { output_ptr[i] = input_ptr[i] / coeff_; diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h index 94c37cdb..672e03dd 100644 --- a/mace/kernels/eltwise.h +++ b/mace/kernels/eltwise.h @@ -32,10 +32,15 @@ enum EltwiseType { MAX = 2, MIN = 3, SUB = 4, + DIV = 5, + NEG = 6, + ABS = 7, + SQR_DIFF = 8, }; struct EltwiseFunctorBase { - EltwiseFunctorBase(const EltwiseType type, const std::vector &coeff) + EltwiseFunctorBase(const EltwiseType type, + const std::vector &coeff) : type_(type), coeff_(coeff) {} EltwiseType type_; @@ -44,74 +49,211 @@ struct EltwiseFunctorBase { template struct EltwiseFunctor : EltwiseFunctorBase { - EltwiseFunctor(const EltwiseType type, const std::vector &coeff) + EltwiseFunctor(const EltwiseType type, + const std::vector &coeff) : EltwiseFunctorBase(type, coeff) {} void operator()(const Tensor *input0, const Tensor *input1, + const index_t start_axis, + const bool is_scaler, + const float 
value, + const bool swap, Tensor *output, StatsFuture *future) { - Tensor::MappingGuard input0_guard(input0); - Tensor::MappingGuard input1_guard(input1); - Tensor::MappingGuard output_guard(output); + if (is_scaler) { + Tensor::MappingGuard input0_guard(input0); + Tensor::MappingGuard output_guard(output); - const T *input0_ptr = input0->data(); - const T *input1_ptr = input1->data(); - T *output_ptr = output->mutable_data(); - const index_t size = input0->size(); - - switch (type_) { - case PROD: + const T *input0_ptr = input0->data(); + T *output_ptr = output->mutable_data(); + const index_t num = input0->size(); + switch (type_) { + case PROD: +#pragma omp parallel for + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = input0_ptr[i] * value; + } + break; + case SUM: + if (coeff_.empty()) { #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = input0_ptr[i] * input1_ptr[i]; - } - break; - case SUM: - if (coeff_.empty()) { + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = input0_ptr[i] + value; + } + } else { + const float coeff_0 = swap ? coeff_[1] : coeff_[0]; + const float coeff_1 = swap ? 
coeff_[0] : coeff_[1]; #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = input0_ptr[i] + input1_ptr[i]; + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = coeff_0 * input0_ptr[i] + + coeff_1 * value; + } } - } else { + break; + case MAX: #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = - coeff_[0] * input0_ptr[i] + coeff_[1] * input1_ptr[i]; + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = std::max(input0_ptr[i], value); } - } - break; - case MAX: + break; + case MIN: #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = std::max(input0_ptr[i], input1_ptr[i]); - } - break; - case MIN: + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = std::min(input0_ptr[i], value); + } + break; + case SUB: #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = std::min(input0_ptr[i], input1_ptr[i]); - } - break; - case SUB: + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = swap ? 
value - input0_ptr[i] : + input0_ptr[i] - value; + } + break; + case DIV: + if (!swap) { + MACE_CHECK(fabs(value) > 1e-6, "cannot divided by 0."); #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = input0_ptr[i] - input1_ptr[i]; - } - break; - default: - LOG(FATAL) << "Eltwise op not support type " << type_; + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = input0_ptr[i] / value; + } + } else { +#pragma omp parallel for + for (index_t i = 0; i < num; ++i) { + MACE_CHECK(fabs(input0_ptr[i]) > 1e-6, "cannot divided by 0."); + output_ptr[i] = value / input0_ptr[i]; + } + } + break; + case SQR_DIFF: +#pragma omp parallel for + for (index_t i = 0; i < num; ++i) { + const float tmp = input0_ptr[i] - value; + output_ptr[i] = tmp * tmp; + } + break; + default: + LOG(FATAL) << "Eltwise op not support type " << type_; + } + } else { + MACE_CHECK_NOTNULL(input0); + MACE_CHECK_NOTNULL(input1); + Tensor::MappingGuard input0_guard(input0); + Tensor::MappingGuard input1_guard(input1); + Tensor::MappingGuard output_guard(output); + + const T *input0_ptr = input0->data(); + const T *input1_ptr = input1->data(); + T *output_ptr = output->mutable_data(); + const index_t size0 = input0->size(); + const index_t size1 = input1->size(); + + const index_t num = size0 / size1; + switch (type_) { + case PROD: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j= 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + input0_ptr[i * size1 + j] * input1_ptr[j]; + } + } + break; + case SUM: + if (coeff_.empty()) { +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + input0_ptr[i * size1 + j] + input1_ptr[j]; + } + } + } else { + const float coeff_0 = swap ? coeff_[1] : coeff_[0]; + const float coeff_1 = swap ? 
coeff_[0] : coeff_[1]; +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + coeff_0 * input0_ptr[i * size1 + j] + + coeff_1 * input1_ptr[j]; + } + } + } + break; + case MAX: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + std::max(input0_ptr[i * size1 + j], input1_ptr[j]); + } + } + break; + case MIN: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + std::min(input0_ptr[i * size1 + j], input1_ptr[j]); + } + } + break; + case SUB: /* when swap is set the caller exchanged the operands, so undo it */ +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = swap ? + input1_ptr[j] - input0_ptr[i * size1 + j] : + input0_ptr[i * size1 + j] - input1_ptr[j]; + } + } + break; + case DIV: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + if (!swap) { + MACE_CHECK(fabs(input1_ptr[j]) > 1e-6, "cannot divided by 0."); + output_ptr[i * size1 + j] = + input0_ptr[i * size1 + j] / input1_ptr[j]; + } else { + MACE_CHECK(fabs(input0_ptr[i * size1 + j]) > 1e-6, + "cannot divided by 0."); + output_ptr[i * size1 + j] = + input1_ptr[j] / input0_ptr[i * size1 + j]; + } + } + } + break; + case SQR_DIFF: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + const T tmp = input0_ptr[i * size1 + j] - input1_ptr[j]; + output_ptr[i * size1 + j] = tmp * tmp; + } + } + break; + default: + LOG(FATAL) << "Eltwise op not support type " << type_; + } } } }; template struct EltwiseFunctor : EltwiseFunctorBase { - EltwiseFunctor(const EltwiseType type, const std::vector &coeff) + EltwiseFunctor(const EltwiseType type, + const std::vector &coeff) :
EltwiseFunctorBase(type, coeff) {} void operator()(const Tensor *input0, const Tensor *input1, + const index_t start_axis, + const bool is_scaler, + const float value, + const bool swap, Tensor *output, StatsFuture *future); diff --git a/mace/kernels/opencl/cl/cwise.cl b/mace/kernels/opencl/cl/cwise.cl index 2d3f3105..52c3a7b2 100644 --- a/mace/kernels/opencl/cl/cwise.cl +++ b/mace/kernels/opencl/cl/cwise.cl @@ -31,7 +31,14 @@ __kernel void cwise(KERNEL_ERROR_PARAMS #elif CWISE_TYPE == 4 out = in0 - in1; #elif CWISE_TYPE == 5 - out = in0 / in1; + if (fabs(in1.x) > 0.000001f) + out.x = in0.x / in1.x; + if (fabs(in1.y) > 0.000001f) + out.y = in0.y / in1.y; + if (fabs(in1.z) > 0.000001f) + out.z = in0.z / in1.z; + if (fabs(in1.w) > 0.000001f) + out.w = in0.w / in1.w; #elif CWISE_TYPE == 6 in1 = (DATA_TYPE4)(0, 0, 0, 0); out = in1 - in0; diff --git a/mace/kernels/opencl/cl/eltwise.cl b/mace/kernels/opencl/cl/eltwise.cl index 58838a7d..b2ebebec 100644 --- a/mace/kernels/opencl/cl/eltwise.cl +++ b/mace/kernels/opencl/cl/eltwise.cl @@ -1,30 +1,62 @@ #include __kernel void eltwise(KERNEL_ERROR_PARAMS - GLOBAL_WORK_GROUP_SIZE_DIM2 - __read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ + GLOBAL_WORK_GROUP_SIZE_DIM3 + __read_only image2d_t input0, __read_only image2d_t input1, + __private const float value, + __private const int height, + __private const int width, + __private const int channel, #ifdef COEFF_SUM __private const float coeff0, __private const float coeff1, #endif __write_only image2d_t output) { - const int w = get_global_id(0); - const int hb = get_global_id(1); + const int c = get_global_id(0); + const int w = get_global_id(1); + const int hb = get_global_id(2); #ifndef NON_UNIFORM_WORK_GROUP - if (w >= global_size_dim0 || hb >= global_size_dim1) return; + if (c >= global_size_dim0 || w >= global_size_dim1 || hb >= global_size_dim2) + return; #endif - DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(w, hb)); - DATA_TYPE4 in1 = READ_IMAGET(input1, 
SAMPLER, (int2)(w, hb)); + int pos_w; + int pos_h; +#if START_AXIS == 0 + pos_w = mad24(c, width, w); + pos_h = hb; +#elif START_AXIS == 1 + pos_w = mad24(c, width, w); + pos_h = hb % height; +#elif START_AXIS == 2 + pos_w = mad24(c, width, w); + pos_h = 0; +#elif START_AXIS == 3 + pos_w = c; + pos_h = 0; +#endif + const int pos = mad24(c, width, w); + const int remain_channel = channel - 4 * c; + DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(pos, hb)); + DATA_TYPE4 in1 ; +#if IS_SCALER == 1 + in1 = (DATA_TYPE4){value, value, value, value}; +#else + in1 = READ_IMAGET(input1, SAMPLER, (int2)(pos_w, pos_h)); +#endif DATA_TYPE4 out; #if ELTWISE_TYPE == 0 out = in0 * in1; #elif ELTWISE_TYPE == 1 #ifdef COEFF_SUM - out = mad(coeff0, in0, mad(coeff1, in1, 0)); + #if NEEDSWAP == 0 + out = mad(coeff0, in0, mad(coeff1, in1, 0)); + #else + out = mad(coeff1, in0, mad(coeff0, in1, 0)); + #endif #else out = in0 + in1; #endif @@ -34,8 +66,49 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS #elif ELTWISE_TYPE == 3 out = fmin(in0, in1); #elif ELTWISE_TYPE == 4 - out = in0 - in1; + #if NEEDSWAP == 0 /* must match the host build option -DNEEDSWAP */ + out = in0 - in1; + #else + out = in1 - in0; + #endif +#elif ELTWISE_TYPE == 5 + #if NEEDSWAP == 0 + if (fabs(in1.x) > 0.000001f) + out.x = in0.x / in1.x; + if (fabs(in1.y) > 0.000001f) + out.y = in0.y / in1.y; + if (fabs(in1.z) > 0.000001f) + out.z = in0.z / in1.z; + if (fabs(in1.w) > 0.000001f) + out.w = in0.w / in1.w; + #else /* swapped: in0 is the divisor, so guard on in0 */ + if (fabs(in0.x) > 0.000001f) + out.x = in1.x / in0.x; + if (fabs(in0.y) > 0.000001f) + out.y = in1.y / in0.y; + if (fabs(in0.z) > 0.000001f) + out.z = in1.z / in0.z; + if (fabs(in0.w) > 0.000001f) + out.w = in1.w / in0.w; + #endif +#elif ELTWISE_TYPE == 8 + DATA_TYPE4 diff = in0 - in1; + out = diff * diff; +#endif + +#if ELTWISE_TYPE == 1 || ELTWISE_TYPE == 2 || ELTWISE_TYPE == 3 \ + || ELTWISE_TYPE == 4 || ELTWISE_TYPE == 8 + if (remain_channel < 4) { + switch (remain_channel) { + case 1: + out.y = 0; + case 2: + out.z = 0; + case 3: + out.w = 0; + } +
} #endif - WRITE_IMAGET(output, (int2)(w, hb), out); + WRITE_IMAGET(output, (int2)(pos, hb), out); } diff --git a/mace/kernels/opencl/eltwise_opencl.cc b/mace/kernels/opencl/eltwise_opencl.cc index 629ba890..0ec4a1e5 100644 --- a/mace/kernels/opencl/eltwise_opencl.cc +++ b/mace/kernels/opencl/eltwise_opencl.cc @@ -23,6 +23,10 @@ namespace kernels { template void EltwiseFunctor::operator()(const Tensor *input0, const Tensor *input1, + const index_t start_axis, + const bool is_scaler, + const float value, + const bool swap, Tensor *output, StatsFuture *future) { const index_t batch = input0->dim(0); @@ -31,14 +35,15 @@ void EltwiseFunctor::operator()(const Tensor *input0, const index_t channels = input0->dim(3); const index_t channel_blocks = RoundUpDiv4(channels); - const index_t width_pixels = channel_blocks * width; const index_t batch_height_pixels = batch * height; - const uint32_t gws[2] = {static_cast(width_pixels), + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width), static_cast(batch_height_pixels)}; + const int scaler = is_scaler ? 1 : 0; + const int need_swap = swap ? 
1 : 0; auto runtime = OpenCLRuntime::Global(); - if (kernel_.get() == nullptr) { std::set built_options; auto dt = DataTypeToEnum::value; @@ -47,6 +52,9 @@ void EltwiseFunctor::operator()(const Tensor *input0, built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace(MakeString("-DELTWISE_TYPE=", type_)); + built_options.emplace(MakeString("-DSTART_AXIS=", start_axis)); + built_options.emplace(MakeString("-DIS_SCALER=", scaler)); + built_options.emplace(MakeString("-DNEEDSWAP=", need_swap)); if (runtime->IsOutOfRangeCheckEnabled()) { built_options.emplace("-DOUT_OF_RANGE_CHECK"); kernel_error_ = std::move(std::unique_ptr( @@ -73,9 +81,14 @@ void EltwiseFunctor::operator()(const Tensor *input0, if (!runtime->IsNonUniformWorkgroupsSupported()) { kernel_.setArg(idx++, gws[0]); kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); } kernel_.setArg(idx++, *(input0->opencl_image())); kernel_.setArg(idx++, *(input1->opencl_image())); + kernel_.setArg(idx++, value); + kernel_.setArg(idx++, static_cast(height)); + kernel_.setArg(idx++, static_cast(width)); + kernel_.setArg(idx++, static_cast(channels)); if (!coeff_.empty()) { kernel_.setArg(idx++, coeff_[0]); kernel_.setArg(idx++, coeff_[1]); @@ -85,11 +98,11 @@ void EltwiseFunctor::operator()(const Tensor *input0, input_shape_ = input0->shape(); } - const std::vector lws = {kwg_size_ / 16, 16, 0}; + const std::vector lws = {8, kwg_size_ / 64, 8, 0}; std::stringstream ss; ss << "eltwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); - TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future); + TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); if (runtime->IsOutOfRangeCheckEnabled()) { kernel_error_->Map(nullptr); char *kerror_code = kernel_error_->mutable_data(); diff --git a/mace/ops/eltwise.h b/mace/ops/eltwise.h index 818fa5e5..2972a83a 100644 --- 
a/mace/ops/eltwise.h +++ b/mace/ops/eltwise.h @@ -32,24 +32,53 @@ class EltwiseOp : public Operator { OperatorBase::GetRepeatedArgument("coeff")) {} bool Run(StatsFuture *future) override { - const Tensor *input0 = this->Input(0); - const Tensor *input1 = this->Input(1); - Tensor *output = this->Output(OUTPUT); - MACE_CHECK(input0->dim_size() == input1->dim_size()) + if (this->InputSize() == 1) { + const Tensor* input = this->Input(0); + Tensor *output = this->Output(OUTPUT); + start_axis_ = input->dim_size() - 1; + is_scaler_ = true; + output->ResizeLike(input); + const float x = OperatorBase::GetSingleArgument("x", 1.0); + functor_(input, nullptr, start_axis_, + is_scaler_, x, false, output, future); + } else { + const index_t size0 = this->Input(0)->size(); + const index_t size1 = this->Input(1)->size(); + const bool swap = (size0 < size1); + const Tensor *input0 = swap ? this->Input(1) : this->Input(0); + const Tensor *input1 = swap ? this->Input(0) : this->Input(1); + + Tensor *output = this->Output(OUTPUT); + MACE_CHECK(input0->dim_size() == input1->dim_size()) << "Inputs of Eltwise op must be same shape"; - for (int i = 0; i < input0->dim_size(); ++i) { - MACE_CHECK(input0->dim(i) == input1->dim(i)) - << "Inputs of Eltwise op must be same shape"; + start_axis_ = input0->dim_size() - 1; + is_scaler_ = (input1->size() == 1); + uint32_t compared_size = 1; + if (!is_scaler_) { + while (start_axis_ >= 0) { + MACE_CHECK(input0->dim(start_axis_) == input1->dim(start_axis_), + "Invalid inputs dimension at axis: ") << start_axis_ + << "input 0: " << input0->dim(start_axis_) + << "input 1: " << input1->dim(start_axis_); + compared_size *= input1->dim(start_axis_); + if (compared_size == input1->size()) { + break; + } + start_axis_--; + } + } + output->ResizeLike(input0); + const float x = OperatorBase::GetSingleArgument("x", 1.0); + functor_(input0, input1, start_axis_, + is_scaler_, x, swap, output, future); } - - output->ResizeLike(input0); - - functor_(input0, 
input1, output, future); return true; } private: kernels::EltwiseFunctor functor_; + index_t start_axis_; + bool is_scaler_; private: OP_OUTPUT_TAGS(OUTPUT); diff --git a/mace/ops/eltwise_test.cc b/mace/ops/eltwise_test.cc index ca24242b..6dd3b33d 100644 --- a/mace/ops/eltwise_test.cc +++ b/mace/ops/eltwise_test.cc @@ -25,23 +25,26 @@ class EltwiseOpTest : public OpsTestBase {}; namespace { template void Simple(const kernels::EltwiseType type, - const std::vector &shape, + const std::vector &shape0, + const std::vector &shape1, const std::vector &input0, const std::vector &input1, const std::vector &output, + const float x = 1.f, const std::vector coeff = {}) { // Construct graph OpsTestNet net; // Add input data - net.AddInputFromArray("Input1", shape, input0); - net.AddInputFromArray("Input2", shape, input1); + net.AddInputFromArray("Input1", shape0, input0); + net.AddInputFromArray("Input2", shape1, input1); if (D == DeviceType::CPU) { OpDefBuilder("Eltwise", "EltwiseTest") .Input("Input1") .Input("Input2") .AddIntArg("type", static_cast(type)) + .AddFloatArg("x", x) .AddFloatsArg("coeff", coeff) .Output("Output") .Finalize(net.NewOperatorDef()); @@ -57,6 +60,7 @@ void Simple(const kernels::EltwiseType type, .Input("InputImg1") .Input("InputImg2") .AddIntArg("type", static_cast(type)) + .AddFloatArg("x", x) .AddFloatsArg("coeff", coeff) .Output("OutputImg") .Finalize(net.NewOperatorDef()); @@ -68,7 +72,7 @@ void Simple(const kernels::EltwiseType type, kernels::BufferType::IN_OUT_CHANNEL); } - auto expected = CreateTensor(shape, output); + auto expected = CreateTensor(shape0, output); ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); } @@ -76,53 +80,200 @@ void Simple(const kernels::EltwiseType type, TEST_F(EltwiseOpTest, CPUSimple) { Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 16, 25, 36}); Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 
3, 4, 5, 6}, {2, 4, 6, 8, 10, 12}); Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, - {3, 6, 9, 12, 15, 18}, {2, 1}); + {3, 6, 9, 12, 15, 18}, 1., {2, 1}); Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, {1, 2, 3, 4, 6, 6}); Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, {1, 1, 3, 3, 5, 6}); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, + {0, 1, 0, 1, 1, 0}); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 2, 10, 24}, + {1, 2, 1, 2, 0.5, 0.25}); + + Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {1, 4, 9, 4, 10, 18}); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {2, 4, 6, 5, 7, 9}); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {3, 6, 9, 9, 12, 15}, 1., {2, 1}); + Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 2, 3, 4, 5, 6}); + Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 1, 3, 1, 1, 3}); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {0, 1, 0, 9, 16, 9}); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 2, 1, 4, 5, 2}); + + Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {2, 4, 6, 8, 10, 12}, 2); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {3, 4, 5, 6, 7, 8}, 2); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {4, 6, 8, 10, 12, 14}, 2, {2, 1}); + 
Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {3, 3, 3, 4, 5, 6}, 3); + Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {1, 2, 3, 3, 3, 3}, 3); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {0.5}, + {2, 4, 6, 8, 10, 12}, 0.5); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {4, 1, 0, 1, 4, 9}, 3); } TEST_F(EltwiseOpTest, GPUSimple) { Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, - {1, 4, 9, 16, 25, 36}); + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, + {1, 4, 9, 16, 25, 36}); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, + {2, 4, 6, 8, 10, 12}); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, + {3, 6, 9, 12, 15, 18}, 1., {2, 1}); + Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, + {1, 2, 3, 4, 6, 6}); + Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, + {1, 1, 3, 3, 5, 6}); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 2, 10, 24}, + {1, 2, 1, 2, 0.5, 0.25}); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, + {0, 1, 0, 1, 1, 0}); + + Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {1, 4, 9, 4, 10, 18}); Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, - {2, 4, 6, 8, 10, 12}); + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {2, 4, 6, 5, 7, 9}); Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, - {3, 6, 9, 12, 15, 18}, {2, 1}); + {1, 1, 1, 3}, + {1, 
2, 3, 4, 5, 6}, {1, 2, 3}, + {3, 6, 9, 9, 12, 15}, 1., {2, 1}); Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, - {1, 2, 3, 4, 6, 6}); + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 2, 3, 4, 5, 6}); Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, - {1, 1, 3, 3, 5, 6}); + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 1, 3, 1, 1, 3}); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {0, 1, 0, 9, 16, 9}); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 2, 1, 4, 5, 2}); + + Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {2, 4, 6, 8, 10, 12}, 2); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {3, 4, 5, 6, 7, 8}, 2); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {4, 6, 8, 10, 12, 14}, 2, {2, 1}); + Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {3, 3, 3, 4, 5, 6}, 3); + Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {1, 2, 3, 3, 3, 3}, 3); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {4, 1, 0, 1, 4, 9}, 3); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {0.5}, + {2, 4, 6, 8, 10, 12}, 0.5); } namespace { template void RandomTest(const kernels::EltwiseType type, - const std::vector &shape) { + const std::vector &shape1, + const std::vector &shape2) { testing::internal::LogToStderr(); srand(time(NULL)); // Construct graph OpsTestNet net; + bool is_divide = (type == kernels::EltwiseType::DIV); + // Add input data - net.AddRandomInput("Input1", shape); - net.AddRandomInput("Input2", shape); + net.AddRandomInput("Input1", shape1, true, 
is_divide); + net.AddRandomInput("Input2", shape2, true, is_divide); + + OpDefBuilder("Eltwise", "EltwiseTest") .Input("Input1") @@ -166,24 +317,110 @@ void RandomTest(const kernels::EltwiseType type, TEST_F(EltwiseOpTest, OPENCLRandomFloat) { RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, {3, 23, 37, 19}); RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, {13, 32, 32, 64}); RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, {3, 32, 32, 64}); RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, + {1, 1, 37, 19}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 63}, + {1, 1, 32, 63}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, + {1, 1, 1, 19}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {1, 1, 1, 64}); } TEST_F(EltwiseOpTest, OPENCLRandomHalf) { RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, {3, 23, 37, 19}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, + {1, 23, 37, 19}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, + {1, 1, 37, 19}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 
19}, + {1, 1, 1, 19}); RandomTest(kernels::EltwiseType::SUM, - {13, 32, 32, 64}); + {13, 32, 32, 64}, + {1, 1, 1, 1}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, + {1, 1, 32, 64}); RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, {3, 32, 32, 64}); + RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, + {1, 1, 32, 64}); RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, + {1, 1, 32, 64}); } } // namespace test diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 16243570..1439bf08 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -150,7 +150,8 @@ class OpsTestNet { template void AddRandomInput(const std::string &name, const std::vector &shape, - bool positive = true) { + bool positive = true, + bool truncate = false) { Tensor *input = ws_.CreateTensor(name, GetDeviceAllocator(D), DataTypeToEnum::v()); input->Resize(shape); @@ -162,14 +163,24 @@ class OpsTestNet { std::normal_distribution nd(0, 1); if (DataTypeToEnum::value == DT_HALF) { std::generate( - input_data, input_data + input->size(), [&gen, &nd, positive] { - return half_float::half_cast(positive ? 
std::abs(nd(gen)) - : nd(gen)); + input_data, input_data + input->size(), + [&gen, &nd, positive, truncate] { + float d = nd(gen); + if (truncate) { + if (std::abs(d) > 100.f) d = 100.f; + if (std::abs(d) < 0.001f) d = 0.001f; + } + return half_float::half_cast(positive ?std::abs(d) : d); }); } else { std::generate(input_data, input_data + input->size(), - [&gen, &nd, positive] { - return positive ? std::abs(nd(gen)) : nd(gen); + [&gen, &nd, positive, truncate] { + float d = nd(gen); + if (truncate) { + if (std::abs(d) > 100.f) d = 100.f; + if (std::abs(d) < 0.001f) d = 0.001f; + } + return (positive ?std::abs(d) : d); }); } } diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 0f0ca20b..678324cb 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -835,31 +835,19 @@ class TFConverter(object): arg.name = 'T' arg.i = self.dt op_def.name = op.name - - if len(op.inputs) == 1: - op_def.type = "CWise" - op_def.input.extend([input.name for input in op.inputs]) - x_arg = op_def.arg.add() - x_arg.name = 'x' - x_arg.f = 0 - elif len(op.inputs) >= 2: + op_def.type = "Eltwise" + op_def.input.extend([input.name for input in op.inputs]) + x_value = op.get_attr('x') + if len(op.inputs) >= 2: input_tensor0 = get_input_tensor(op, 0) input_tensor1 = get_input_tensor(op, 1) - if input_tensor0.shape == input_tensor1.shape: - op_def.type = "Eltwise" - op_def.input.extend([input.name for input in op.inputs]) - else: - op_def.type = "CWise" - x_value = 0 - if len(input_tensor1.shape) == 4: - op_def.input.extend([op.inputs[1].name]) - x_value = get_input_tensor(op, 0).eval().astype(np.float32) - else: - op_def.input.extend([op.inputs[0].name]) - x_value = get_input_tensor(op, 1).eval().astype(np.float32) - x_arg = op_def.arg.add() - x_arg.name = 'x' - x_arg.f = x_value + if len(input_tensor0) == 1: + x_value = input_tensor0.eval().astype(np.float32) + elif len(input_tensor1) == 1: + x_value = 
input_tensor1.eval().astype(np.float32) + x_arg = op_def.arg.add() + x_arg.name = 'x' + x_arg.f = x_value type_arg = op_def.arg.add() type_arg.name = 'type' type_arg.i = math_type_mode[math_type] -- GitLab