From 6392835fa34259c7c74979de8fffcfd2b73d44f6 Mon Sep 17 00:00:00 2001 From: yejianwu Date: Fri, 17 Aug 2018 14:39:57 +0800 Subject: [PATCH] format code, fix fill input type, update op list docs --- docs/user_guide/op_lists.rst | 6 ++--- mace/core/runtime/opencl/opencl_allocator.cc | 4 ++-- mace/core/types.cc | 2 +- mace/kernels/fill.h | 22 +++++++++---------- mace/kernels/opencl/helper.h | 2 +- mace/ops/concat_test.cc | 8 +++---- mace/ops/fill.h | 7 +++--- mace/ops/fill_test.cc | 8 ++++--- mace/ops/identity_test.cc | 2 +- mace/ops/reshape.h | 6 ++--- mace/ops/reshape_test.cc | 2 +- mace/ops/squeeze_test.cc | 2 +- .../converter_tool/tensorflow_converter.py | 4 ---- mace/test/mace_api_mt_test.cc | 2 +- mace/test/mace_api_test.cc | 2 +- 15 files changed, 38 insertions(+), 41 deletions(-) diff --git a/docs/user_guide/op_lists.rst b/docs/user_guide/op_lists.rst index 860c2f5f..63a03346 100644 --- a/docs/user_guide/op_lists.rst +++ b/docs/user_guide/op_lists.rst @@ -12,7 +12,7 @@ Operator lists "BIAS_ADD","Y","" "CAST","Y","Only CPU and TensorFlow model is supported." "CHANNEL_SHUFFLE","Y","" - "CONCATENATION","Y","Only support channel axis concatenation." + "CONCATENATION","Y","For GPU only support channel axis concatenation." "CONV_2D","Y","Fusion with BN and activation layer is supported." "CROP","Y","Only Caffe's crop layer is supported (in GPU, offset on channel-dim should be dividable by 4)." "DECONV_2D","Y","Supports Caffe's Deconvolution and TensorFlow's tf.layers.conv2d_transpose." @@ -20,7 +20,7 @@ Operator lists "DEPTH_TO_SPACE","Y","" "DEQUANTIZE","Y","Model quantization will be supported later." "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL" - "EMBEDDING_LOOKUP","Y","Only support channel axis concatenation." + "EMBEDDING_LOOKUP","Y","" "FULLY_CONNECTED","Y","" "GROUP_CONV_2D","","Caffe model with group count = channel count is supported." "IDENTITY","Y","Only TensorFlow model is supported." @@ -44,7 +44,7 @@ Operator lists "SHAPE","Y","Only CPU and TensorFlow is supported." "STACK","Y","Only CPU and TensorFlow is supported." "STRIDEDSLICE","Y","Only CPU and TensorFlow is supported." - "SLICE","Y","In TensorFlow, this op is equivalent to SPLIT; Only support channel axis slice." + "SPLIT","Y","In Caffe, this op is equivalent to SLICE; For GPU only support channel axis slice." "SOFTMAX","Y","" "SPACE_TO_BATCH_ND", "Y","" "SPACE_TO_DEPTH","Y","" diff --git a/mace/core/runtime/opencl/opencl_allocator.cc b/mace/core/runtime/opencl/opencl_allocator.cc index 7dda80e6..86b0138d 100644 --- a/mace/core/runtime/opencl/opencl_allocator.cc +++ b/mace/core/runtime/opencl/opencl_allocator.cc @@ -70,7 +70,7 @@ MaceStatus OpenCLAllocator::New(size_t nbytes, void **result) const { MaceStatus OpenCLAllocator::NewImage(const std::vector &image_shape, const DataType dt, void **result) const { - MACE_CHECK(image_shape.size() == 2) << "Image shape's size must equal 2"; + MACE_CHECK(image_shape.size() == 2, "Image shape's size must equal 2"); VLOG(3) << "Allocate OpenCL image: " << image_shape[0] << ", " << image_shape[1]; @@ -134,7 +134,7 @@ void *OpenCLAllocator::Map(void *buffer, size_t offset, size_t nbytes) const { void *OpenCLAllocator::MapImage(void *buffer, const std::vector &image_shape, std::vector *mapped_image_pitch) const { - MACE_CHECK(image_shape.size() == 2) << "Just support map 2d image"; + MACE_CHECK(image_shape.size() == 2, "Just support map 2d image"); auto cl_image = static_cast(buffer); std::array origin = {0, 0, 0}; std::array region = {image_shape[0], image_shape[1], 1}; diff --git a/mace/core/types.cc b/mace/core/types.cc index 05b6acb3..8f29bcc0 100644 --- a/mace/core/types.cc +++ b/mace/core/types.cc @@ -39,7 +39,7 @@ std::string DataTypeToString(const DataType dt) { #endif {DT_UINT8, "DT_UINT8"}, {DT_INT32, "DT_UINT32"}}; - MACE_CHECK(dt != DT_INVALID) << "Not support Invalid data type"; + MACE_CHECK(dt != DT_INVALID, "Not support Invalid data type"); return dtype_string_map[dt]; } diff --git a/mace/kernels/fill.h b/mace/kernels/fill.h index 5e172c3f..b534a183 100644 --- a/mace/kernels/fill.h +++ b/mace/kernels/fill.h @@ -26,41 +26,39 @@ namespace mace { namespace kernels { -struct FillBase { - explicit FillBase(float value) : value_(value) {} - - int value_; -}; - template struct FillFunctor; template <> -struct FillFunctor : FillBase { - explicit FillFunctor(float value) : FillBase(value) {} +struct FillFunctor { + FillFunctor() {} MaceStatus operator()(const Tensor *shape, + const Tensor *value, Tensor *output, StatsFuture *future) { MACE_UNUSED(future); - MACE_CHECK(shape->dim_size() == 1) << "Shape must be 1-D"; + MACE_CHECK(shape->dim_size() == 1, "Shape must be 1-D"); const index_t num_dims = shape->dim(0); Tensor::MappingGuard shape_guard(shape); const int32_t *shape_data = shape->data(); std::vector output_shape; for (index_t i = 0; i < num_dims; ++i) { - MACE_CHECK(shape_data[i] > 0) << "Shape must be non-negative: " - << shape_data[i]; + MACE_CHECK(shape_data[i] > 0, "Shape must be non-negative: ", + shape_data[i]); output_shape.push_back(shape_data[i]); } + Tensor::MappingGuard value_guard(value); + const float *value_data = value->data(); + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); Tensor::MappingGuard output_guard(output); float *output_data = output->mutable_data(); - std::fill(output_data, output_data + output->size(), value_); + std::fill(output_data, output_data + output->size(), *value_data); return MACE_SUCCESS; } diff --git a/mace/kernels/opencl/helper.h b/mace/kernels/opencl/helper.h index 22d9f1cc..5d4bf410 100644 --- a/mace/kernels/opencl/helper.h +++ b/mace/kernels/opencl/helper.h @@ -58,7 +58,7 @@ namespace kernels { if (runtime->IsOutOfRangeCheckEnabled()) { \ (kernel_error)->Map(nullptr); \ char *kerror_code = (kernel_error)->mutable_data(); \ - MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;\ + MACE_CHECK(*kerror_code == 0, "Kernel error code: ", *kerror_code);\ (kernel_error)->UnMap(); \ } diff --git a/mace/ops/concat_test.cc b/mace/ops/concat_test.cc index 9076aa27..f8b6b42a 100644 --- a/mace/ops/concat_test.cc +++ b/mace/ops/concat_test.cc @@ -55,10 +55,10 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) { const float *output_ptr = output->data(); for (auto f : input0) { - ASSERT_EQ(f, *output_ptr++); + EXPECT_EQ(f, *output_ptr++); } for (auto f : input1) { - ASSERT_EQ(f, *output_ptr++); + EXPECT_EQ(f, *output_ptr++); } } @@ -93,10 +93,10 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) { const float *output_ptr = output->data(); for (int i = 0; i < 4; ++i) { for (int j = 0; j < 4; ++j) { - ASSERT_EQ(input0[i * 4 + j], *output_ptr++); + EXPECT_EQ(input0[i * 4 + j], *output_ptr++); } for (int j = 0; j < 4; ++j) { - ASSERT_EQ(input1[i * 4 + j], *output_ptr++); + EXPECT_EQ(input1[i * 4 + j], *output_ptr++); } } } diff --git a/mace/ops/fill.h b/mace/ops/fill.h index 3e2c6df7..a8b55dbe 100644 --- a/mace/ops/fill.h +++ b/mace/ops/fill.h @@ -28,18 +28,19 @@ class FillOp : public Operator { public: FillOp(const OperatorDef &operator_def, Workspace *ws) : Operator(operator_def, ws), - functor_(OperatorBase::GetOptionalArg("value", 0.0f)) {} + functor_() {} MaceStatus Run(StatsFuture *future) override { const Tensor *shape = this->Input(SHAPE); + const Tensor *value = this->Input(VALUE); Tensor *output = this->Output(OUTPUT); - return functor_(shape, output, future); + return functor_(shape, value, output, future); } private: kernels::FillFunctor functor_; - MACE_OP_INPUT_TAGS(SHAPE); + MACE_OP_INPUT_TAGS(SHAPE, VALUE); MACE_OP_OUTPUT_TAGS(OUTPUT); }; diff --git a/mace/ops/fill_test.cc b/mace/ops/fill_test.cc index bc3a3363..1808b0b5 100644 --- a/mace/ops/fill_test.cc +++ b/mace/ops/fill_test.cc @@ -28,7 +28,7 @@ void TestFill(const std::vector &shape, OpsTestNet net; OpDefBuilder("Fill", "FillTest") .Input("Shape") - .AddFloatArg("value", static_cast(value)) + .Input("Value") .Output("Output") .Finalize(net.NewOperatorDef()); @@ -38,19 +38,21 @@ void TestFill(const std::vector &shape, {static_cast(shape.size())}, shape); + net.AddInputFromArray("Value", {}, {value}); + // Run net.RunOp(); auto output = net.GetTensor("Output"); for (index_t i = 0; i < output->dim_size(); ++i) { - ASSERT_EQ(output->dim(i), shape[i]); + EXPECT_EQ(output->dim(i), shape[i]); } const float *output_ptr = output->data(); const index_t size = output->size(); for (index_t i = 0; i < size; ++i) { - ASSERT_EQ(output_ptr[i], value); + EXPECT_EQ(output_ptr[i], value); } } } // namespace diff --git a/mace/ops/identity_test.cc b/mace/ops/identity_test.cc index 26d835ce..988ce760 100644 --- a/mace/ops/identity_test.cc +++ b/mace/ops/identity_test.cc @@ -46,7 +46,7 @@ void TestIdentity(const std::vector &shape) { const float *output_ptr = output->data(); const int size = output->size(); for (int i = 0; i < size; ++i) { - ASSERT_EQ(input_ptr[i], output_ptr[i]); + EXPECT_EQ(input_ptr[i], output_ptr[i]); } } } // namespace diff --git a/mace/ops/reshape.h b/mace/ops/reshape.h index 90a44314..c47e6cb1 100644 --- a/mace/ops/reshape.h +++ b/mace/ops/reshape.h @@ -42,12 +42,12 @@ class ReshapeOp : public Operator { for (int i = 0; i < num_dims; ++i) { if (shape_data[i] == -1) { - MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1"; + MACE_CHECK(unknown_idx == -1, "Only one input size may be -1"); unknown_idx = i; out_shape.push_back(1); } else { - MACE_CHECK(shape_data[i] >= 0) << "Shape must be non-negative: " - << shape_data[i]; + MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ", + shape_data[i]); out_shape.push_back(shape_data[i]); product *= shape_data[i]; } diff --git a/mace/ops/reshape_test.cc b/mace/ops/reshape_test.cc index 91c0f82b..947e968b 100644 --- a/mace/ops/reshape_test.cc +++ b/mace/ops/reshape_test.cc @@ -53,7 +53,7 @@ void TestReshape(const std::vector &org_shape, const float *output_ptr = output->data(); const int size = output->size(); for (int i = 0; i < size; ++i) { - ASSERT_EQ(input_ptr[i], output_ptr[i]); + EXPECT_EQ(input_ptr[i], output_ptr[i]); } } } // namespace diff --git a/mace/ops/squeeze_test.cc b/mace/ops/squeeze_test.cc index 35f224c9..fba5a37d 100644 --- a/mace/ops/squeeze_test.cc +++ b/mace/ops/squeeze_test.cc @@ -49,7 +49,7 @@ void TestSqueeze(const std::vector &org_shape, const float *output_ptr = output->data(); const int size = output->size(); for (int i = 0; i < size; ++i) { - ASSERT_EQ(input_ptr[i], output_ptr[i]); + EXPECT_EQ(input_ptr[i], output_ptr[i]); } } } // namespace diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index be4678ed..9583d0e1 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -464,10 +464,6 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) op.type = MaceOp.Fill.name - value_arg = op.arg.add() - value_arg.name = MaceKeyword.mace_value_str - value_arg.f = tf_op.inputs[1].eval() - def convert_fused_batchnorm(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.FoldedBatchNorm.name diff --git a/mace/test/mace_api_mt_test.cc b/mace/test/mace_api_mt_test.cc index 27c601fe..e2a09fec 100644 --- a/mace/test/mace_api_mt_test.cc +++ b/mace/test/mace_api_mt_test.cc @@ -342,7 +342,7 @@ void MaceRunFunc(const int in_out_size) { MaceEngine engine(device); MaceStatus status = engine.Init(net_def.get(), input_names, output_names, reinterpret_cast(data.data())); - ASSERT_EQ(status, MaceStatus::MACE_SUCCESS); + EXPECT_EQ(status, MaceStatus::MACE_SUCCESS); std::map inputs; std::map outputs; diff --git a/mace/test/mace_api_test.cc b/mace/test/mace_api_test.cc index 46bd9fe1..6b1f353e 100644 --- a/mace/test/mace_api_test.cc +++ b/mace/test/mace_api_test.cc @@ -336,7 +336,7 @@ void MaceRun(const int in_out_size, MaceEngine engine(device); MaceStatus status = engine.Init(net_def.get(), input_names, output_names, reinterpret_cast(data.data())); - ASSERT_EQ(status, MaceStatus::MACE_SUCCESS); + EXPECT_EQ(status, MaceStatus::MACE_SUCCESS); std::map inputs; std::map outputs; -- GitLab