diff --git a/docs/user_guide/advanced_usage.rst b/docs/user_guide/advanced_usage.rst index 8017f40163e0dbc3bab955199b1b465e6f0ec65b..738f4448b85a9c2b9ba4d264655ea83e81822f2f 100644 --- a/docs/user_guide/advanced_usage.rst +++ b/docs/user_guide/advanced_usage.rst @@ -82,6 +82,8 @@ in one deployment file. - The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU. * - data_type - [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU, default is fp16_fp32, [fp32] for CPU and [uint8] for DSP. + * - input_data_types + - [optional] The input data types for specific ops (e.g. gather), which can be [int32, float32], default is float32. * - limit_opencl_kernel_time - [optional] Whether splitting the OpenCL kernel within 1 ms to keep UI responsiveness, default is 0. * - obfuscate diff --git a/docs/user_guide/op_lists.rst b/docs/user_guide/op_lists.rst index 860c2f5fcaf35181e8caeb3f5ce5798c5b026a00..63a033467c7ac3856aecbff3ed7d9c4ceafeea8e 100644 --- a/docs/user_guide/op_lists.rst +++ b/docs/user_guide/op_lists.rst @@ -12,7 +12,7 @@ Operator lists "BIAS_ADD","Y","" "CAST","Y","Only CPU and TensorFlow model is supported." "CHANNEL_SHUFFLE","Y","" - "CONCATENATION","Y","Only support channel axis concatenation." + "CONCATENATION","Y","For GPU, only channel axis concatenation is supported." "CONV_2D","Y","Fusion with BN and activation layer is supported." "CROP","Y","Only Caffe's crop layer is supported (in GPU, offset on channel-dim should be dividable by 4)." "DECONV_2D","Y","Supports Caffe's Deconvolution and TensorFlow's tf.layers.conv2d_transpose." @@ -20,7 +20,7 @@ Operator lists "DEPTH_TO_SPACE","Y","" "DEQUANTIZE","Y","Model quantization will be supported later." "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL" - "EMBEDDING_LOOKUP","Y","Only support channel axis concatenation." + "EMBEDDING_LOOKUP","Y","" "FULLY_CONNECTED","Y","" "GROUP_CONV_2D","","Caffe model with group count = channel count is supported." "IDENTITY","Y","Only TensorFlow model is supported." @@ -44,7 +44,7 @@ Operator lists "SHAPE","Y","Only CPU and TensorFlow is supported." "STACK","Y","Only CPU and TensorFlow is supported." "STRIDEDSLICE","Y","Only CPU and TensorFlow is supported." - "SLICE","Y","In TensorFlow, this op is equivalent to SPLIT; Only support channel axis slice." + "SPLIT","Y","In Caffe, this op is equivalent to SLICE; for GPU, only channel axis split is supported."
"SOFTMAX","Y","" "SPACE_TO_BATCH_ND", "Y","" "SPACE_TO_DEPTH","Y","" diff --git a/mace/core/runtime/opencl/opencl_allocator.cc b/mace/core/runtime/opencl/opencl_allocator.cc index 7dda80e62fd4ca66cf4f2e109e4ab0f5653ae05d..86b0138d727da41171c315fde3e121d88877fb04 100644 --- a/mace/core/runtime/opencl/opencl_allocator.cc +++ b/mace/core/runtime/opencl/opencl_allocator.cc @@ -70,7 +70,7 @@ MaceStatus OpenCLAllocator::New(size_t nbytes, void **result) const { MaceStatus OpenCLAllocator::NewImage(const std::vector &image_shape, const DataType dt, void **result) const { - MACE_CHECK(image_shape.size() == 2) << "Image shape's size must equal 2"; + MACE_CHECK(image_shape.size() == 2, "Image shape's size must equal 2"); VLOG(3) << "Allocate OpenCL image: " << image_shape[0] << ", " << image_shape[1]; @@ -134,7 +134,7 @@ void *OpenCLAllocator::Map(void *buffer, size_t offset, size_t nbytes) const { void *OpenCLAllocator::MapImage(void *buffer, const std::vector &image_shape, std::vector *mapped_image_pitch) const { - MACE_CHECK(image_shape.size() == 2) << "Just support map 2d image"; + MACE_CHECK(image_shape.size() == 2, "Just support map 2d image"); auto cl_image = static_cast(buffer); std::array origin = {0, 0, 0}; std::array region = {image_shape[0], image_shape[1], 1}; diff --git a/mace/core/types.cc b/mace/core/types.cc index 05b6acb3c5af3072c774f6620b6d4f9077a72b8b..8f29bcc0d61bc9c20f7fe68947ae9b56153f9333 100644 --- a/mace/core/types.cc +++ b/mace/core/types.cc @@ -39,7 +39,7 @@ std::string DataTypeToString(const DataType dt) { #endif {DT_UINT8, "DT_UINT8"}, {DT_INT32, "DT_UINT32"}}; - MACE_CHECK(dt != DT_INVALID) << "Not support Invalid data type"; + MACE_CHECK(dt != DT_INVALID, "Not support Invalid data type"); return dtype_string_map[dt]; } diff --git a/mace/kernels/fill.h b/mace/kernels/fill.h new file mode 100644 index 0000000000000000000000000000000000000000..b534a1839c77d183441e9cff74c1de6a917fa648 --- /dev/null +++ b/mace/kernels/fill.h @@ -0,0 +1,70 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MACE_KERNELS_FILL_H_ +#define MACE_KERNELS_FILL_H_ + +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/public/mace.h" + +namespace mace { +namespace kernels { + +template +struct FillFunctor; + +template <> +struct FillFunctor { + FillFunctor() {} + + MaceStatus operator()(const Tensor *shape, + const Tensor *value, + Tensor *output, + StatsFuture *future) { + MACE_UNUSED(future); + + MACE_CHECK(shape->dim_size() == 1, "Shape must be 1-D"); + const index_t num_dims = shape->dim(0); + Tensor::MappingGuard shape_guard(shape); + const int32_t *shape_data = shape->data(); + + std::vector output_shape; + for (index_t i = 0; i < num_dims; ++i) { + MACE_CHECK(shape_data[i] > 0, "Shape must be non-negative: ", + shape_data[i]); + output_shape.push_back(shape_data[i]); + } + + Tensor::MappingGuard value_guard(value); + const float *value_data = value->data(); + + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + Tensor::MappingGuard output_guard(output); + float *output_data = output->mutable_data(); + + std::fill(output_data, output_data + output->size(), *value_data); + + return MACE_SUCCESS; + } +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_FILL_H_ diff --git a/mace/kernels/opencl/cl/slice.cl b/mace/kernels/opencl/cl/split.cl similarity index 95% rename from mace/kernels/opencl/cl/slice.cl rename to mace/kernels/opencl/cl/split.cl index f6b0c35a95249129a2d235557744ff9c522b7e84..8f93742ec552e294d29f19310989e96e62bf4d54 100644 --- a/mace/kernels/opencl/cl/slice.cl +++ b/mace/kernels/opencl/cl/split.cl @@ -1,6 +1,6 @@ #include -__kernel void slice(KERNEL_ERROR_PARAMS +__kernel void split(KERNEL_ERROR_PARAMS GLOBAL_WORK_GROUP_SIZE_DIM3 __read_only image2d_t input, __private const int chan_blk_offset, diff --git a/mace/kernels/opencl/helper.h b/mace/kernels/opencl/helper.h index 22d9f1cc548c8691b313db12f6693a86bdbf957b..5d4bf4104172ac093212fcb023941e9bb0015b6c 100644 --- a/mace/kernels/opencl/helper.h +++ b/mace/kernels/opencl/helper.h @@ -58,7 +58,7 @@ namespace kernels { if (runtime->IsOutOfRangeCheckEnabled()) { \ (kernel_error)->Map(nullptr); \ char *kerror_code = (kernel_error)->mutable_data(); \ - MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;\ + MACE_CHECK(*kerror_code == 0, "Kernel error code: ", *kerror_code);\ (kernel_error)->UnMap(); \ } diff --git a/mace/kernels/opencl/slice.cc b/mace/kernels/opencl/split.cc similarity index 90% rename from mace/kernels/opencl/slice.cc rename to mace/kernels/opencl/split.cc index b778e0d70aa16f3aa6141a6ed3198a6caec188cf..65fd6be530898200e50cc74518813cd01e7c9d15 100644 --- a/mace/kernels/opencl/slice.cc +++ b/mace/kernels/opencl/split.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "mace/kernels/slice.h" +#include "mace/kernels/split.h" #include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/kernels/opencl/helper.h" #include "mace/utils/tuner.h" @@ -21,7 +21,7 @@ namespace mace { namespace kernels { template -MaceStatus SliceFunctor::operator()( +MaceStatus SplitFunctor::operator()( const Tensor *input, const std::vector &output_list, StatsFuture *future) { @@ -29,7 +29,7 @@ MaceStatus SliceFunctor::operator()( const size_t outputs_count = output_list.size(); const index_t output_channels = input_channels / outputs_count; MACE_CHECK(output_channels % 4 == 0) - << "output channels of slice op must be divisible by 4"; + << "output channels of split op must be divisible by 4"; std::vector output_shape( {input->dim(0), input->dim(1), input->dim(2), output_channels}); @@ -46,12 +46,12 @@ MaceStatus SliceFunctor::operator()( std::set built_options; OUT_OF_RANGE_CONFIG(kernel_error_); NON_UNIFORM_WG_CONFIG; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("slice"); - built_options.emplace("-Dslice=" + kernel_name); + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("split"); + built_options.emplace("-Dsplit=" + kernel_name); built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum::value)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DataTypeToEnum::value)); - MACE_RETURN_IF_ERROR(runtime->BuildKernel("slice", + MACE_RETURN_IF_ERROR(runtime->BuildKernel("split", kernel_name, built_options, &kernel_)); @@ -116,8 +116,8 @@ MaceStatus SliceFunctor::operator()( return MACE_SUCCESS; } -template struct SliceFunctor; -template struct SliceFunctor; +template struct SplitFunctor; +template struct SplitFunctor; } // namespace kernels } // namespace mace diff --git a/mace/kernels/slice.h b/mace/kernels/split.h similarity index 89% rename from mace/kernels/slice.h rename to mace/kernels/split.h index 7ab311b01ed5cea1b40b59eef71a1bb6c704cbc6..95ff7861142e3f146f461328d04d1d21f2eb5a51 100644 --- a/mace/kernels/slice.h +++ b/mace/kernels/split.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef MACE_KERNELS_SLICE_H_ -#define MACE_KERNELS_SLICE_H_ +#ifndef MACE_KERNELS_SPLIT_H_ +#define MACE_KERNELS_SPLIT_H_ #include #include @@ -31,15 +31,15 @@ namespace mace { namespace kernels { -struct SliceFunctorBase { - explicit SliceFunctorBase(const int32_t axis) : axis_(axis) {} +struct SplitFunctorBase { + explicit SplitFunctorBase(const int32_t axis) : axis_(axis) {} int32_t axis_; }; template -struct SliceFunctor : SliceFunctorBase { - explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {} +struct SplitFunctor : SplitFunctorBase { + explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {} MaceStatus operator()(const Tensor *input, const std::vector &output_list, @@ -89,8 +89,8 @@ struct SliceFunctor : SliceFunctorBase { #ifdef MACE_ENABLE_OPENCL template -struct SliceFunctor : SliceFunctorBase { - explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {} +struct SplitFunctor : SplitFunctorBase { + explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {} MaceStatus operator()(const Tensor *input, const std::vector &output_list, @@ -104,4 +104,4 @@ struct SliceFunctor : SliceFunctorBase { } // namespace kernels } // namespace mace -#endif // MACE_KERNELS_SLICE_H_ +#endif // MACE_KERNELS_SPLIT_H_ diff --git a/mace/kernels/strided_slice.h b/mace/kernels/strided_slice.h index eab4a4d5441ed36c6d8f779209127187ed7a6d5a..20c508aa73109fe5124d18804938a29d578720a0 100644 --- a/mace/kernels/strided_slice.h +++ b/mace/kernels/strided_slice.h @@ -169,7 +169,6 @@ struct StridedSliceFunctor { i += strides_data[0]) { *output_data++ = input_data[i]; } - } else if (input->dim_size() == 2) { for (index_t i = real_begin_indices[0]; strides_data[0] > 0 ? i < real_end_indices[0] @@ -179,7 +178,25 @@ struct StridedSliceFunctor { strides_data[1] > 0 ? j < real_end_indices[1] : j > real_end_indices[1]; j += strides_data[1]) { - *output_data++ = input_data[i * dim_stride[0] + j]; + *output_data++ = input_data[i * input->dim(1) + j]; + } + } + } else if (input->dim_size() == 3) { + for (index_t i = real_begin_indices[0]; + strides_data[0] > 0 ? i < real_end_indices[0] + : i > real_end_indices[0]; + i += strides_data[0]) { + for (index_t j = real_begin_indices[1]; + strides_data[1] > 0 ? j < real_end_indices[1] + : j > real_end_indices[1]; + j += strides_data[1]) { + for (index_t k = real_begin_indices[2]; + strides_data[2] > 0 ? 
k < real_end_indices[2] + : k > real_end_indices[2]; + k += strides_data[2]) { + *output_data++ = + input_data[(i * input->dim(1) + j) * input->dim(2) + k]; + } } } } else { diff --git a/mace/ops/concat_test.cc b/mace/ops/concat_test.cc index 9076aa2768157da4f42adf0ac39cba0fed4ba751..f8b6b42a7824d3ee2824ca60fffd585e0daf864c 100644 --- a/mace/ops/concat_test.cc +++ b/mace/ops/concat_test.cc @@ -55,10 +55,10 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) { const float *output_ptr = output->data(); for (auto f : input0) { - ASSERT_EQ(f, *output_ptr++); + EXPECT_EQ(f, *output_ptr++); } for (auto f : input1) { - ASSERT_EQ(f, *output_ptr++); + EXPECT_EQ(f, *output_ptr++); } } @@ -93,10 +93,10 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) { const float *output_ptr = output->data(); for (int i = 0; i < 4; ++i) { for (int j = 0; j < 4; ++j) { - ASSERT_EQ(input0[i * 4 + j], *output_ptr++); + EXPECT_EQ(input0[i * 4 + j], *output_ptr++); } for (int j = 0; j < 4; ++j) { - ASSERT_EQ(input1[i * 4 + j], *output_ptr++); + EXPECT_EQ(input1[i * 4 + j], *output_ptr++); } } } diff --git a/mace/ops/fill.cc b/mace/ops/fill.cc new file mode 100644 index 0000000000000000000000000000000000000000..93e6daddcf50c9db4b7dea2196a2e275e2620d18 --- /dev/null +++ b/mace/ops/fill.cc @@ -0,0 +1,29 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/fill.h" + +namespace mace { +namespace ops { + +void Register_Fill(OperatorRegistryBase *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Fill") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + FillOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/fill.h b/mace/ops/fill.h new file mode 100644 index 0000000000000000000000000000000000000000..a8b55dbe8984f2d6f87e39e1d39373e9ad909b58 --- /dev/null +++ b/mace/ops/fill.h @@ -0,0 +1,50 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MACE_OPS_FILL_H_ +#define MACE_OPS_FILL_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/fill.h" + +namespace mace { +namespace ops { + +template +class FillOp : public Operator { + public: + FillOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws), + functor_() {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *shape = this->Input(SHAPE); + const Tensor *value = this->Input(VALUE); + Tensor *output = this->Output(OUTPUT); + return functor_(shape, value, output, future); + } + + private: + kernels::FillFunctor functor_; + + MACE_OP_INPUT_TAGS(SHAPE, VALUE); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_FILL_H_ diff --git a/mace/ops/fill_test.cc b/mace/ops/fill_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1808b0b52bbbe2ab9ac46246b63a83477292895e --- /dev/null +++ b/mace/ops/fill_test.cc @@ -0,0 +1,67 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class FillTest : public OpsTestBase {}; + +namespace { +void TestFill(const std::vector &shape, + const float &value) { + // Construct graph + OpsTestNet net; + OpDefBuilder("Fill", "FillTest") + .Input("Shape") + .Input("Value") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddInputFromArray( + "Shape", + {static_cast(shape.size())}, + shape); + + net.AddInputFromArray("Value", {}, {value}); + + // Run + net.RunOp(); + + auto output = net.GetTensor("Output"); + + for (index_t i = 0; i < output->dim_size(); ++i) { + EXPECT_EQ(output->dim(i), shape[i]); + } + + const float *output_ptr = output->data(); + const index_t size = output->size(); + for (index_t i = 0; i < size; ++i) { + EXPECT_EQ(output_ptr[i], value); + } +} +} // namespace + +TEST_F(FillTest, Simple) { + TestFill({3, 2, 1}, 5.0f); + TestFill({1, 3}, -1.0f); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/identity_test.cc b/mace/ops/identity_test.cc index 26d835ce4d2260eb3f5aa95d57ab79f86523e357..988ce760c56d96a79f14520a857ce300e4869b00 100644 --- a/mace/ops/identity_test.cc +++ b/mace/ops/identity_test.cc @@ -46,7 +46,7 @@ void TestIdentity(const std::vector &shape) { const float *output_ptr = output->data(); const int size = output->size(); for (int i = 0; i < size; ++i) { - ASSERT_EQ(input_ptr[i], output_ptr[i]); + EXPECT_EQ(input_ptr[i], output_ptr[i]); } } } // namespace diff --git a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc index 886546e332b2b830ad02b1df105db940c5fe84a2..3afe66c9c1c408993a18dfde19d8f1e63ba920fa 100644 --- a/mace/ops/ops_register.cc +++ b/mace/ops/ops_register.cc @@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry); 
extern void Register_Dequantize(OperatorRegistryBase *op_registry); extern void Register_Eltwise(OperatorRegistryBase *op_registry); +extern void Register_Fill(OperatorRegistryBase *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry); extern void Register_FullyConnected(OperatorRegistryBase *op_registry); extern void Register_Gather(OperatorRegistryBase *op_registry); @@ -48,7 +49,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry); extern void Register_Reshape(OperatorRegistryBase *op_registry); extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry); extern void Register_Shape(OperatorRegistryBase *op_registry); -extern void Register_Slice(OperatorRegistryBase *op_registry); +extern void Register_Split(OperatorRegistryBase *op_registry); extern void Register_Softmax(OperatorRegistryBase *op_registry); extern void Register_Stack(OperatorRegistryBase *op_registry); extern void Register_StridedSlice(OperatorRegistryBase *op_registry); @@ -84,6 +85,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ops::Register_DepthwiseConv2d(this); ops::Register_Dequantize(this); ops::Register_Eltwise(this); + ops::Register_Fill(this); ops::Register_FoldedBatchNorm(this); ops::Register_FullyConnected(this); ops::Register_Gather(this); @@ -98,7 +100,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ops::Register_Reshape(this); ops::Register_ResizeBilinear(this); ops::Register_Shape(this); - ops::Register_Slice(this); + ops::Register_Split(this); ops::Register_Softmax(this); ops::Register_Stack(this); ops::Register_StridedSlice(this); diff --git a/mace/ops/reshape.h b/mace/ops/reshape.h index 90a443144bb87d32f8d99d722ef75554195772a8..c47e6cb1791e2fbd3e1fa1aa0506d9189f6dd0f1 100644 --- a/mace/ops/reshape.h +++ b/mace/ops/reshape.h @@ -42,12 +42,12 @@ class ReshapeOp : public Operator { for (int i = 0; i < num_dims; ++i) { if (shape_data[i] == -1) { - MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1"; + MACE_CHECK(unknown_idx == -1, "Only one input size may be -1"); unknown_idx = i; out_shape.push_back(1); } else { - MACE_CHECK(shape_data[i] >= 0) << "Shape must be non-negative: " - << shape_data[i]; + MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ", + shape_data[i]); out_shape.push_back(shape_data[i]); product *= shape_data[i]; } diff --git a/mace/ops/reshape_test.cc b/mace/ops/reshape_test.cc index 91c0f82b7ae24c7da41ca5e504fa3aace600f29f..947e968b9dac4d7f163a635a56da14b619f883ce 100644 --- a/mace/ops/reshape_test.cc +++ b/mace/ops/reshape_test.cc @@ -53,7 +53,7 @@ void TestReshape(const std::vector &org_shape, const float *output_ptr = output->data(); const int size = output->size(); for (int i = 0; i < size; ++i) { - ASSERT_EQ(input_ptr[i], output_ptr[i]); + EXPECT_EQ(input_ptr[i], output_ptr[i]); } } } // namespace diff --git a/mace/ops/slice.cc b/mace/ops/split.cc similarity index 74% rename from mace/ops/slice.cc rename to mace/ops/split.cc index b6bf4b24e7fd6e974448e9751866503429cdea84..e5e103d7b1dfedb0e3ef26b9f2fbe0e84525ee6f 100644 --- a/mace/ops/slice.cc +++ b/mace/ops/split.cc @@ -12,30 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "mace/ops/slice.h" +#include "mace/ops/split.h" namespace mace { namespace ops { -void Register_Slice(OperatorRegistryBase *op_registry) { - MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") +void Register_Split(OperatorRegistryBase *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split") .Device(DeviceType::CPU) .TypeConstraint("T") .Build(), - SliceOp); + SplitOp); #ifdef MACE_ENABLE_OPENCL - MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split") .Device(DeviceType::GPU) .TypeConstraint("T") .Build(), - SliceOp); + SplitOp); - MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split") .Device(DeviceType::GPU) .TypeConstraint("T") .Build(), - SliceOp); + SplitOp); #endif // MACE_ENABLE_OPENCL } diff --git a/mace/ops/slice.h b/mace/ops/split.h similarity index 78% rename from mace/ops/slice.h rename to mace/ops/split.h index 7f01162f67fb161f610f81ae04c4d6bf688400c7..710cdfb343de578c59830022b5e702e5ee99dd18 100644 --- a/mace/ops/slice.h +++ b/mace/ops/split.h @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef MACE_OPS_SLICE_H_ -#define MACE_OPS_SLICE_H_ +#ifndef MACE_OPS_SPLIT_H_ +#define MACE_OPS_SPLIT_H_ #include #include "mace/core/operator.h" -#include "mace/kernels/slice.h" +#include "mace/kernels/split.h" namespace mace { namespace ops { template -class SliceOp : public Operator { +class SplitOp : public Operator { public: - SliceOp(const OperatorDef &op_def, Workspace *ws) + SplitOp(const OperatorDef &op_def, Workspace *ws) : Operator(op_def, ws), functor_(OperatorBase::GetOptionalArg("axis", 3)) {} @@ -35,15 +35,15 @@ class SliceOp : public Operator { << "There must be at least two outputs for slicing"; const Tensor *input = this->Input(INPUT); const std::vector output_list = this->Outputs(); - const int32_t slice_axis = OperatorBase::GetOptionalArg("axis", 3); - MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0) + const int32_t split_axis = OperatorBase::GetOptionalArg("axis", 3); + MACE_CHECK((input->dim(split_axis) % this->OutputSize()) == 0) << "Outputs do not split input equally."; return functor_(input, output_list, future); } private: - kernels::SliceFunctor functor_; + kernels::SplitFunctor functor_; private: MACE_OP_INPUT_TAGS(INPUT); @@ -52,4 +52,4 @@ class SliceOp : public Operator { } // namespace ops } // namespace mace -#endif // MACE_OPS_SLICE_H_ +#endif // MACE_OPS_SPLIT_H_ diff --git a/mace/ops/slice_benchmark.cc b/mace/ops/split_benchmark.cc similarity index 78% rename from mace/ops/slice_benchmark.cc rename to mace/ops/split_benchmark.cc index c02dbf5c08b8aa35ac4946a04161a70dbbe69b18..8dea1263c8f1761b33b4dd63be7ef25915e4157b 100644 --- a/mace/ops/slice_benchmark.cc +++ b/mace/ops/split_benchmark.cc @@ -22,7 +22,7 @@ namespace test { namespace { template -void BMSliceHelper(int iters, +void BMSplitHelper(int iters, const std::vector &input_shape, const index_t num_outputs) { mace::testing::StopTiming(); @@ -42,7 +42,7 @@ void BMSliceHelper(int iters, BufferToImage(&net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); - auto builder = OpDefBuilder("Slice", "SliceTest"); + auto builder = OpDefBuilder("Split", "SplitTest"); builder.Input("InputImage"); for (int i = 0; i < num_outputs; ++i) { builder = builder.Output(MakeString("OutputImage", i)); @@ -51,7 +51,7 @@ void BMSliceHelper(int iters, 
.AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); } else { - auto builder = OpDefBuilder("Slice", "SliceTest"); + auto builder = OpDefBuilder("Split", "SplitTest"); builder.Input("Input"); for (int i = 0; i < num_outputs; ++i) { builder = builder.Output(MakeString("Output", i)); @@ -73,28 +73,28 @@ void BMSliceHelper(int iters, } } // namespace -#define MACE_BM_SLICE_MACRO(N, H, W, C, NO, TYPE, DEVICE) \ +#define MACE_BM_SPLIT_MACRO(N, H, W, C, NO, TYPE, DEVICE) \ static void \ - MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \ + MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \ int iters) { \ const int64_t tot = static_cast(iters) * N * H * W * C; \ mace::testing::MaccProcessed(tot); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - BMSliceHelper(iters, {N, H, W, C}, NO); \ + BMSplitHelper(iters, {N, H, W, C}, NO); \ } \ MACE_BENCHMARK( \ - MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE) + MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE) -#define MACE_BM_SLICE(N, H, W, C, NO) \ - MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, CPU); \ - MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, GPU); \ - MACE_BM_SLICE_MACRO(N, H, W, C, NO, half, GPU); +#define MACE_BM_SPLIT(N, H, W, C, NO) \ + MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, CPU); \ + MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, GPU); \ + MACE_BM_SPLIT_MACRO(N, H, W, C, NO, half, GPU); -MACE_BM_SLICE(1, 32, 32, 32, 2); -MACE_BM_SLICE(1, 32, 32, 128, 2); -MACE_BM_SLICE(1, 32, 32, 256, 2); -MACE_BM_SLICE(1, 128, 128, 32, 2); -MACE_BM_SLICE(1, 128, 128, 128, 2); +MACE_BM_SPLIT(1, 32, 32, 32, 2); +MACE_BM_SPLIT(1, 32, 32, 128, 2); +MACE_BM_SPLIT(1, 32, 32, 256, 2); +MACE_BM_SPLIT(1, 128, 128, 32, 2); +MACE_BM_SPLIT(1, 128, 128, 128, 2); } // namespace test } // namespace ops diff --git a/mace/ops/slice_test.cc b/mace/ops/split_test.cc similarity index 93% rename from mace/ops/slice_test.cc rename to mace/ops/split_test.cc index b445d56ab7438ce992070a5e83003f4536c55406..57544d18c62741434ea1162e179c3ff1856ab43a 100644 --- a/mace/ops/slice_test.cc +++ b/mace/ops/split_test.cc @@ -17,13 +17,13 @@ #include "gmock/gmock.h" #include "mace/ops/ops_test_util.h" -#include "mace/ops/slice.h" +#include "mace/ops/split.h" namespace mace { namespace ops { namespace test { -class SliceOpTest : public OpsTestBase {}; +class SplitOpTest : public OpsTestBase {}; namespace { template @@ -53,7 +53,7 @@ void RandomTest(const int num_outputs, const int axis) { BufferToImage(&net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); - auto builder = OpDefBuilder("Slice", "SliceTest"); + auto builder = OpDefBuilder("Split", "SplitTest"); builder.Input("InputImage"); for (int i = 0; i < num_outputs; ++i) { builder = builder.Output(MakeString("OutputImage", i)); @@ -61,7 +61,7 @@ void RandomTest(const int num_outputs, const int axis) { builder.AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); } else { - auto builder = OpDefBuilder("Slice", "SliceTest").AddIntArg("axis", axis); + auto builder = OpDefBuilder("Split", "SplitTest").AddIntArg("axis", axis); builder.Input("Input"); for (int i = 0; i < num_outputs; ++i) { builder = builder.Output(MakeString("Output", i)); @@ -111,25 +111,25 @@ void RandomTest(const int num_outputs, const int axis) { } } // namespace -TEST_F(SliceOpTest, CPU) { +TEST_F(SplitOpTest, CPU) { RandomTest(2, 3); RandomTest(4, 3); RandomTest(11, 3); } -TEST_F(SliceOpTest, CPUAxis1) { +TEST_F(SplitOpTest, CPUAxis1) 
{ RandomTest(2, 1); RandomTest(4, 1); RandomTest(11, 1); } -TEST_F(SliceOpTest, OPENCLFloat) { +TEST_F(SplitOpTest, OPENCLFloat) { RandomTest(2, 3); RandomTest(4, 3); RandomTest(11, 3); } -TEST_F(SliceOpTest, OPENCLHalf) { +TEST_F(SplitOpTest, OPENCLHalf) { RandomTest(2, 3); RandomTest(4, 3); RandomTest(11, 3); diff --git a/mace/ops/squeeze_test.cc b/mace/ops/squeeze_test.cc index 35f224c9a901dab81de1469c9218a0bb3b7debd8..fba5a37d245ea1c878753a96d39c2bf820af071e 100644 --- a/mace/ops/squeeze_test.cc +++ b/mace/ops/squeeze_test.cc @@ -49,7 +49,7 @@ void TestSqueeze(const std::vector &org_shape, const float *output_ptr = output->data(); const int size = output->size(); for (int i = 0; i < size; ++i) { - ASSERT_EQ(input_ptr[i], output_ptr[i]); + EXPECT_EQ(input_ptr[i], output_ptr[i]); } } } // namespace diff --git a/mace/ops/strided_slice_test.cc b/mace/ops/strided_slice_test.cc index 322f1135d14ce281ecc14baf10bb2eb102e9a8d6..d975d7beb922f40e648ffbdd009091537b5425c7 100644 --- a/mace/ops/strided_slice_test.cc +++ b/mace/ops/strided_slice_test.cc @@ -146,6 +146,18 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank2) { 0, 3, {}, {6}); } +TEST_F(StridedSliceOpTest, TestStridedSliceRank3) { + TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + {0, 0, 0}, {2, 3, 2}, {1, 2, 1}, 0, 0, 0, 0, 0, {2, 2, 2}, + {1, 2, 5, 6, 7, 8, 11, 12}); + TestStridedSlice({3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, + 6, 6}, {1, 0, 0}, {2, 1, 3}, {1, 1, 1}, 0, 0, 0, 0, 0, {1, + 1, 3}, {3, 3, 3}); + TestStridedSlice({3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, + 6, 6}, {0, 0, 0}, {2, 2, 2}, {1, 2, 1}, 0, 0, 0, 0, 0, {2, + 1, 2}, {1, 1, 3, 3}); +} + TEST_F(StridedSliceOpTest, TestSlice) { TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3}, {1, 2, 3, 4, 5, 6}); diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 33ef662cf91715aec3fef0ab838d62eb45fa7b1f..9a5440f440b9307f268956361d5f31e2eb3505c1 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -88,6 +88,7 @@ MaceSupportedOps = [ 'Dequantize', 'Eltwise', 'FoldedBatchNorm', + 'Fill', 'FullyConnected', 'Gather', 'Identity', @@ -101,6 +102,7 @@ MaceSupportedOps = [ 'Reshape', 'ResizeBilinear', 'Slice', + 'Split', 'Shape', 'Squeeze', 'Stack', @@ -146,6 +148,7 @@ class MaceKeyword(object): mace_constant_value_str = 'constant_value' mace_dims_str = 'dims' mace_axis_str = 'axis' + mace_num_split_str = 'num_split' mace_keepdims_str = 'keepdims' mace_shape_str = 'shape' mace_winograd_filter_transformed = 'is_filter_transformed' diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index fca6ca95b335f030cbf7ebfe92a376ed1512e2c7..9583d0e163be75b3a0e92afdfcb5a994bc59c5a6 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -68,6 +68,7 @@ TFSupportedOps = [ 'Relu6', 'Tanh', 'Sigmoid', + 'Fill', 'FusedBatchNorm', 'AvgPool', 'MaxPool', @@ -165,6 +166,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.Relu6.name: self.convert_activation, TFOpType.Tanh.name: self.convert_activation, TFOpType.Sigmoid.name: self.convert_activation, + TFOpType.Fill.name: self.convert_fill, TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm, TFOpType.AvgPool.name: self.convert_pooling, TFOpType.MaxPool.name: 
self.convert_pooling, @@ -458,6 +460,10 @@ class TensorflowConverter(base_converter.ConverterInterface): limit_arg.name = MaceKeyword.mace_activation_max_limit_str limit_arg.f = 6.0 + def convert_fill(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.Fill.name + def convert_fused_batchnorm(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.FoldedBatchNorm.name @@ -763,19 +769,19 @@ class TensorflowConverter(base_converter.ConverterInterface): op.output_type.extend([mace_pb2.DT_INT32]) def convert_split(self, tf_op): - # inputs: [dim, input] axis = tf_op.inputs[0].eval().astype(np.int32) axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis - mace_check(axis == 3, 'Split with %d axis only support' % axis) input_shape = self.infer_tensor_shape(tf_op.inputs[1]) - mace_check(len(input_shape) == 4 and (input_shape[3] % 4 == 0), - "The input's 4th dimension should be a multiple of 4") op = self.convert_general_op(tf_op) - op.type = MaceOp.Slice.name + op.type = MaceOp.Split.name del op.input[0] axis_arg = op.arg.add() axis_arg.name = MaceKeyword.mace_axis_str axis_arg.i = axis + num_split_arg = op.arg.add() + num_split_arg.name = MaceKeyword.mace_num_split_str + num_split_arg.i = tf_op.get_attr('num_split') + self._skip_tensor.add(tf_op.inputs[0].name) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 1e68cb14a90ea40f3c7b0cb38832efc69b855bcb..16d9eae007d305e12276d9c0bcd66e0ba1e9f1d3 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -812,6 +812,7 @@ class Transformer(base_converter.ConverterInterface): "only support concat at " "channel dimension") arg.i = 3 + producer = self._producer[op.input[0]] input_shape = producer.output_shape[0].dims if producer.type == MaceOp.FullyConnected.name and \ diff --git a/mace/test/mace_api_mt_test.cc b/mace/test/mace_api_mt_test.cc index 27c601fe8410d57adef4a0179d70f14e8d8ade4e..e2a09fec8d3991fd8dad65b8427ae61ea35b8c3a 100644 --- a/mace/test/mace_api_mt_test.cc +++ b/mace/test/mace_api_mt_test.cc @@ -342,7 +342,7 @@ void MaceRunFunc(const int in_out_size) { MaceEngine engine(device); MaceStatus status = engine.Init(net_def.get(), input_names, output_names, reinterpret_cast(data.data())); - ASSERT_EQ(status, MaceStatus::MACE_SUCCESS); + EXPECT_EQ(status, MaceStatus::MACE_SUCCESS); std::map inputs; std::map outputs; diff --git a/mace/test/mace_api_test.cc b/mace/test/mace_api_test.cc index 46bd9fe1f9306325f3b82a35ab877c89e7af7162..6b1f353eb8f7a3d77e59b84f23fcf3141bfef148 100644 --- a/mace/test/mace_api_test.cc +++ b/mace/test/mace_api_test.cc @@ -336,7 +336,7 @@ void MaceRun(const int in_out_size, MaceEngine engine(device); MaceStatus status = engine.Init(net_def.get(), input_names, output_names, reinterpret_cast(data.data())); - ASSERT_EQ(status, MaceStatus::MACE_SUCCESS); + EXPECT_EQ(status, MaceStatus::MACE_SUCCESS); std::map inputs; std::map outputs; diff --git a/repository/opencl-kernel/opencl_kernel_configure.bzl b/repository/opencl-kernel/opencl_kernel_configure.bzl index 0da8838d59f525fd4b4778b68de0204275253a17..0d1b9cf0ca9e7e72d383e9cf593f95e1a60c66ae 100644 --- a/repository/opencl-kernel/opencl_kernel_configure.bzl +++ b/repository/opencl-kernel/opencl_kernel_configure.bzl @@ -42,7 +42,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx): unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl")) unused_var = 
repository_ctx.path(Label("//:mace/kernels/opencl/cl/reduce_mean.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bilinear.cl")) - unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/slice.cl")) + unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/split.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl")) diff --git a/tools/converter.py b/tools/converter.py index b94aa80d325ad673dea08ab778c577cf659879b0..509d1eceed52ea185c996c06f6045b906c8bb45b 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -130,6 +130,16 @@ class RuntimeType(object): cpu_gpu = 'cpu+gpu' +InputDataTypeStrs = [ + "int32", + "float32", +] + +InputDataType = Enum('InputDataType', + [(ele, ele) for ele in InputDataTypeStrs], + type=str) + + CPUDataTypeStrs = [ "fp32", ] @@ -183,6 +193,7 @@ class YAMLKeyword(object): output_shapes = 'output_shapes' runtime = 'runtime' data_type = 'data_type' + input_data_types = 'input_data_types' limit_opencl_kernel_time = 'limit_opencl_kernel_time' nnlib_graph_mode = 'nnlib_graph_mode' obfuscate = 'obfuscate' @@ -447,6 +458,18 @@ def format_model_config(flags): if not isinstance(value, list): subgraph[key] = [value] + input_data_types = subgraph.get(YAMLKeyword.input_data_types, "") + if input_data_types: + if not isinstance(input_data_types, list): + subgraph[YAMLKeyword.input_data_types] = [input_data_types] + for input_data_type in input_data_types: + mace_check(input_data_type in InputDataTypeStrs, + ModuleName.YAML_CONFIG, + "'input_data_types' must be in " + + str(InputDataTypeStrs)) + else: + subgraph[YAMLKeyword.input_data_types] = [] + validation_threshold = subgraph.get( YAMLKeyword.validation_threshold, {}) if not isinstance(validation_threshold, dict): @@ -1025,7 +1048,8 @@ def tuning(library_name, model_name, model_config, subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.validation_inputs_data], - input_ranges=subgraphs[0][YAMLKeyword.input_ranges]) + input_ranges=subgraphs[0][YAMLKeyword.input_ranges], + input_data_types=subgraphs[0][YAMLKeyword.input_data_types]) sh_commands.tuning_run( abi=target_abi, @@ -1170,7 +1194,8 @@ def run_specific_target(flags, configs, target_abi, subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.validation_inputs_data], - input_ranges=subgraphs[0][YAMLKeyword.input_ranges]) + input_ranges=subgraphs[0][YAMLKeyword.input_ranges], + input_data_types=subgraphs[0][YAMLKeyword.input_data_types]) runtime_list = [] if target_abi == ABIType.host: @@ -1236,6 +1261,7 @@ def run_specific_target(flags, configs, target_abi, output_shapes=subgraphs[0][YAMLKeyword.output_shapes], model_output_dir=model_output_dir, phone_data_dir=PHONE_DATA_DIR, + input_data_types=subgraphs[0][YAMLKeyword.input_data_types], # noqa caffe_env=flags.caffe_env, validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type]) # noqa if flags.report and flags.round > 0: @@ -1478,7 +1504,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num): subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.validation_inputs_data], - input_ranges=subgraphs[0][YAMLKeyword.input_ranges]) + 
input_ranges=subgraphs[0][YAMLKeyword.input_ranges], + input_data_types=subgraphs[0][YAMLKeyword.input_data_types]) runtime_list = [] if target_abi == ABIType.host: runtime_list.extend([RuntimeType.cpu]) diff --git a/tools/generate_data.py b/tools/generate_data.py index d62297cc10423dcb16a4a405fae0c9e686fcd4e2..1e485f2034aeaad6e3d25ccfda24936cd827e880 100644 --- a/tools/generate_data.py +++ b/tools/generate_data.py @@ -27,30 +27,37 @@ import common # --input_ranges -1,1 -def generate_data(name, shape, input_file, tensor_range): +def generate_data(name, shape, input_file, tensor_range, input_data_type): np.random.seed() data = np.random.random(shape) * (tensor_range[1] - tensor_range[0]) \ + tensor_range[0] input_file_name = common.formatted_file_name(input_file, name) print 'Generate input file: ', input_file_name - data.astype(np.float32).tofile(input_file_name) + if input_data_type == 'float32': + np_data_type = np.float32 + elif input_data_type == 'int32': + np_data_type = np.int32 + data.astype(np_data_type).tofile(input_file_name) -def generate_input_data(input_file, input_node, input_shape, input_ranges): +def generate_input_data(input_file, input_node, input_shape, input_ranges, + input_data_type): input_names = [name for name in input_node.split(',')] input_shapes = [shape for shape in input_shape.split(':')] if input_ranges: input_ranges = [r for r in input_ranges.split(':')] else: - input_ranges = None - assert len(input_names) == len(input_shapes) + input_ranges = [[-1, 1]] * len(input_names) + if input_data_type: + input_data_types = [data_type + for data_type in input_data_type.split(',')] + else: + input_data_types = ['float32'] * len(input_names) + assert len(input_names) == len(input_shapes) == len(input_ranges) == len(input_data_types) # noqa for i in range(len(input_names)): shape = [int(x) for x in input_shapes[i].split(',')] - if input_ranges: - input_range = [float(x) for x in input_ranges[i].split(',')] - else: - input_range = [-1, 1] - generate_data(input_names[i], shape, input_file, input_range) + generate_data(input_names[i], shape, input_file, input_ranges[i], + input_data_types[i]) print "Generate input file done." 
@@ -66,6 +73,8 @@ def parse_args(): "--input_shape", type=str, default="1,64,64,3", help="input shape.") parser.add_argument( "--input_ranges", type=str, default="-1,1", help="input range.") + parser.add_argument( + "--input_data_type", type=str, default="", help="input data type.") return parser.parse_known_args() @@ -73,4 +82,4 @@ if __name__ == '__main__': FLAGS, unparsed = parse_args() generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape, - FLAGS.input_ranges) + FLAGS.input_ranges, FLAGS.input_data_type) diff --git a/tools/sh_commands.py b/tools/sh_commands.py index d9b50342459db5a318dc17bea1786f287004a0d5..ebef1c4a750556a609240d01c200e9845606f595 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -536,6 +536,7 @@ def gen_random_input(model_output_dir, input_shapes, input_files, input_ranges, + input_data_types, input_file_name="model_input"): for input_name in input_nodes: formatted_name = common.formatted_file_name( @@ -545,10 +546,12 @@ def gen_random_input(model_output_dir, input_nodes_str = ",".join(input_nodes) input_shapes_str = ":".join(input_shapes) input_ranges_str = ":".join(input_ranges) + input_data_types_str = ",".join(input_data_types) generate_input_data("%s/%s" % (model_output_dir, input_file_name), input_nodes_str, input_shapes_str, - input_ranges_str) + input_ranges_str, + input_data_types_str) input_file_list = [] if isinstance(input_files, list): @@ -800,6 +803,7 @@ def validate_model(abi, output_shapes, model_output_dir, phone_data_dir, + input_data_types, caffe_env, input_file_name="model_input", output_file_name="model_out", @@ -821,7 +825,7 @@ def validate_model(abi, "%s/%s" % (model_output_dir, output_file_name), device_type, ":".join(input_shapes), ":".join(output_shapes), ",".join(input_nodes), ",".join(output_nodes), - validation_threshold) + validation_threshold, ",".join(input_data_types)) elif platform == "caffe": image_name = "mace-caffe:latest" container_name = "mace_caffe_validator" diff --git a/tools/validate.py b/tools/validate.py index 87bb3458e3ee0045645c6e5c9347bad549533594..516cf5128dddf1f1557bfdb8dd5db5a0364660bc 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -40,11 +40,13 @@ import common VALIDATION_MODULE = 'VALIDATION' -def load_data(file): +def load_data(file, data_type='float32'): if os.path.isfile(file): - return np.fromfile(file=file, dtype=np.float32) - else: - return np.empty([0]) + if data_type == 'float32': + return np.fromfile(file=file, dtype=np.float32) + elif data_type == 'int32': + return np.fromfile(file=file, dtype=np.int32) + return np.empty([0]) def compare_output(platform, device_type, output_name, mace_out_value, @@ -78,7 +80,7 @@ def normalize_tf_tensor_name(name): def validate_tf_model(platform, device_type, model_file, input_file, mace_out_file, input_names, input_shapes, - output_names, validation_threshold): + output_names, validation_threshold, input_data_types): import tensorflow as tf if not os.path.isfile(model_file): common.MaceLogger.error( @@ -98,7 +100,8 @@ def validate_tf_model(platform, device_type, model_file, input_file, input_dict = {} for i in range(len(input_names)): input_value = load_data( - common.formatted_file_name(input_file, input_names[i])) + common.formatted_file_name(input_file, input_names[i]), + input_data_types[i]) input_value = input_value.reshape(input_shapes[i]) input_node = graph.get_tensor_by_name( normalize_tf_tensor_name(input_names[i])) @@ -168,18 +171,23 @@ def validate_caffe_model(platform, device_type, model_file,
input_file, def validate(platform, model_file, weight_file, input_file, mace_out_file, device_type, input_shape, output_shape, input_node, output_node, - validation_threshold): + validation_threshold, input_data_type): input_names = [name for name in input_node.split(',')] input_shape_strs = [shape for shape in input_shape.split(':')] input_shapes = [[int(x) for x in shape.split(',')] for shape in input_shape_strs] + if input_data_type: + input_data_types = [data_type + for data_type in input_data_type.split(',')] + else: + input_data_types = ['float32'] * len(input_names) output_names = [name for name in output_node.split(',')] assert len(input_names) == len(input_shapes) if platform == 'tensorflow': validate_tf_model(platform, device_type, model_file, input_file, mace_out_file, input_names, input_shapes, - output_names, validation_threshold) + output_names, validation_threshold, input_data_types) elif platform == 'caffe': output_shape_strs = [shape for shape in output_shape.split(':')] output_shapes = [[int(x) for x in shape.split(',')] @@ -220,6 +228,11 @@ def parse_args(): "--output_shape", type=str, default="1,64,64,2", help="output shape.") parser.add_argument( "--input_node", type=str, default="input_node", help="input node") + parser.add_argument( + "--input_data_type", + type=str, + default="", + help="input data type") parser.add_argument( "--output_node", type=str, default="output_node", help="output node") parser.add_argument( @@ -241,4 +254,5 @@ if __name__ == '__main__': FLAGS.output_shape, FLAGS.input_node, FLAGS.output_node, - FLAGS.validation_threshold) + FLAGS.validation_threshold, + FLAGS.input_data_type)
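The Slice-to-Split rename tracks TensorFlow naming (tf.split) while leaving Caffe's Slice as the equivalent source op, and the converter now records num_split alongside axis. As an illustration of the semantics (assuming the NHWC layout used by MACE's GPU path, with illustrative shapes), Split divides the input evenly along the given axis, and the OpenCL kernel additionally requires each output's channel count to be a multiple of 4:

    import numpy as np

    # NHWC input with 8 channels split into 2 equal parts along the channel axis,
    # which is what the Split op does for axis=3, num_split=2.
    x = np.arange(1 * 2 * 2 * 8, dtype=np.float32).reshape(1, 2, 2, 8)
    parts = np.split(x, 2, axis=3)

    print([p.shape for p in parts])  # [(1, 2, 2, 4), (1, 2, 2, 4)]
    # Each output keeps 4 channels, satisfying the "output channels of split op
    # must be divisible by 4" check in mace/kernels/opencl/split.cc.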
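The new rank-3 branch in mace/kernels/strided_slice.h walks the output in row-major order and computes the flat input index as (i * dim(1) + j) * dim(2) + k. The first case added to strided_slice_test.cc can be cross-checked with NumPy slicing; this sketch only covers the mask-free path the tests exercise:

    import numpy as np

    # Input of shape (2, 3, 2) holding 1..12, as in TestStridedSliceRank3.
    x = np.arange(1, 13, dtype=np.int32).reshape(2, 3, 2)

    # begin = (0, 0, 0), end = (2, 3, 2), strides = (1, 2, 1), all masks zero.
    out = x[0:2:1, 0:3:2, 0:2:1]

    print(out.shape)    # (2, 2, 2)
    print(out.ravel())  # [ 1  2  5  6  7  8 11 12], the expected output in the test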
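The new Fill op reads a 1-D int32 shape tensor and a scalar value tensor and produces an output of that shape with every element set to the value; tensorflow_converter.py maps TensorFlow's Fill (tf.fill) onto it. A NumPy reference of the expected result, matching the cases in fill_test.cc:

    import numpy as np

    def fill_reference(shape, value):
        # Reference behaviour of kernels::FillFunctor<CPU, float>: every element equals value.
        return np.full(shape, value, dtype=np.float32)

    print(fill_reference([3, 2, 1], 5.0).shape)  # (3, 2, 1), all elements 5.0
    print(fill_reference([1, 3], -1.0))          # [[-1. -1. -1.]]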