diff --git a/mace/core/net.cc b/mace/core/net.cc index e71003af8693c0bd4bd49166573fd7e364c88a66..fbe1c1b8b9da81929732a77c176195f29dd688b9 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -175,7 +175,6 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry, // NHWC -> NCHW input_shape = TransposeShape(input_shape, {0, 3, 1, 2}); - input_data_format = DataFormat::NCHW; } } } diff --git a/mace/ops/crop.cc b/mace/ops/crop.cc index a9af8b8db88c05a6735339e868c4080a4ca1c389..3dda169dd80f02a258d854ce88c7f511beab0167 100644 --- a/mace/ops/crop.cc +++ b/mace/ops/crop.cc @@ -15,21 +15,34 @@ #include #include "mace/core/operator.h" +#include "mace/utils/math.h" +#include "mace/utils/memory.h" #ifdef MACE_ENABLE_OPENCL #include "mace/ops/opencl/image/crop.h" #endif // MACE_ENABLE_OPENCL -#include "mace/utils/memory.h" namespace mace { namespace ops { template -class CropOp : public Operation { +class CropOp; + +template +class CropOp : public Operation { public: explicit CropOp(OpConstructContext *context) : Operation(context), - axis_(Operation::GetOptionalArg("axis", 2)), - offset_(Operation::GetRepeatedArgs("offset")) {} + offset_(Operation::GetRepeatedArgs("offset")) { + MACE_CHECK(offset_.size() == 4, + "crop op only supports 4-dims inputs now."); + auto has_df = Operation::GetOptionalArg( + "has_data_format", 0); + if (has_df) { + // NHWC -> NCHW + offset_ = TransposeShape(offset_, {0, 3, 1, 2}); + } + } + MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); @@ -47,21 +60,13 @@ class CropOp : public Operation { std::vector output_shape(input0->shape()); for (index_t i = 0; i < in0_dims; ++i) { - int32_t crop_offset = 0; - index_t new_size = input0->dim(i); - if (i >= axis_) { - new_size = input1->dim(i); - if (offset_.size() == 1) { - crop_offset = offset_[0]; - } else if (offset_.size() > 1) { - crop_offset = offset_[i - axis_]; - } - MACE_CHECK(input0->dim(i) - crop_offset >= input1->dim(i)) - << "the crop for dimension" << i << "is out of bound with size" - << input1->dim(i) << "and offset" << crop_offset; + if (offset_[i] >= 0) { + output_shape[i] = input1->dim(i); + offsets[i] = offset_[i]; + MACE_CHECK(input0->dim(i) - offset_[i] >= input1->dim(i)) + << "the crop for dimension " << i << " is out of bound with size " + << input1->dim(i) << " and offset " << offsets[i]; } - output_shape[i] = new_size; - offsets[i] = crop_offset; } MACE_RETURN_IF_ERROR(output->Resize(output_shape)); T *output_data = output->mutable_data(); @@ -103,7 +108,6 @@ class CropOp : public Operation { } private: - const int axis_; std::vector offset_; }; @@ -113,10 +117,9 @@ class CropOp : public Operation { public: explicit CropOp(OpConstructContext *context) : Operation(context) { - const int axis = Operation::GetOptionalArg("axis", 2); if (context->device()->gpu_runtime()->UseImageMemory()) { kernel_ = make_unique>( - axis, Operation::GetRepeatedArgs("offset")); + Operation::GetRepeatedArgs("offset")); } else { MACE_NOT_IMPLEMENTED; } diff --git a/mace/ops/crop_benchmark.cc b/mace/ops/crop_benchmark.cc index 4ca25b15a3cd607e9b8394bc090e502486cc93e7..724d8ca2958360e991031b003af59f4a3f27b183 100644 --- a/mace/ops/crop_benchmark.cc +++ b/mace/ops/crop_benchmark.cc @@ -21,107 +21,80 @@ namespace test { namespace { template -void CropHelper(int iters, int crop_axis, int dim1, int offset) { +void CropHelper(int iters, + const std::vector &shape0, + const std::vector &shape1, + int crop_axis, + int offset) { mace::testing::StopTiming(); OpsTestNet net; - OpDefBuilder("Crop", "CropBM") - .Input("Input0") - .Input("Input1") - .AddIntArg("axis", crop_axis) - .AddIntsArg("offset", {offset}) - .Output("Output") - .Finalize(net.NewOperatorDef()); - // Add input data - const int kDim0 = 100; - net.AddRandomInput("Input0", {1, kDim0, dim1, dim1, }); - net.AddRandomInput("Input1", - {1, kDim0 / 2, dim1 / 2, dim1 / 2}); + std::vector offsets(4, -1); - // Warm-up - for (int i = 0; i < 5; ++i) { - net.RunOp(D); + for (int i = crop_axis; i < 4; ++i) { + offsets[i] = offset; } - const int64_t tot = static_cast(iters) * kDim0 * dim1 * dim1; - testing::BytesProcessed(tot * sizeof(T)); - mace::testing::StartTiming(); - while (iters--) { - net.RunOp(D); - } -} -} // namespace - -#define MACE_BM_CROP_CPU_MACRO(AXIS, DIM, OFFSET) \ - static void MACE_BM_CROP_CPU_##AXIS##_##DIM##_##OFFSET(int iters) { \ - CropHelper(iters, AXIS, DIM, OFFSET); \ - } \ - MACE_BENCHMARK(MACE_BM_CROP_CPU_##AXIS##_##DIM##_##OFFSET) - -MACE_BM_CROP_CPU_MACRO(1, 256, 3); -MACE_BM_CROP_CPU_MACRO(2, 256, 3); -MACE_BM_CROP_CPU_MACRO(3, 512, 3); -MACE_BM_CROP_CPU_MACRO(2, 512, 6); - -namespace { -template -void OpenCLCropHelper(int iters, - const std::vector &shape0, - const std::vector &shape1, - int crop_axis, - int offset) { - mace::testing::StopTiming(); - - OpsTestNet net; - // Add input data - net.AddRandomInput("Input0", shape0); - net.AddRandomInput("Input1", shape1); + if (D == DeviceType::CPU) { + auto input_shape0 = TransposeShape(shape0, {0, 3, 1, 2}); + auto input_shape1 = TransposeShape(shape1, {0, 3, 1, 2}); + net.AddRandomInput("Input0", input_shape0); + net.AddRandomInput("Input1", input_shape1); + } else if (D == DeviceType::GPU) { + // Add input data + net.AddRandomInput("Input0", shape0); + net.AddRandomInput("Input1", shape1); + } else { + MACE_NOT_IMPLEMENTED; + } OpDefBuilder("Crop", "CropBM") .Input("Input0") .Input("Input1") - .AddIntArg("axis", crop_axis) - .AddIntsArg("offset", {offset}) + .AddIntsArg("offset", offsets) + .AddIntArg("has_data_format", 1) .Output("Output") .AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); // Warm-up - for (int i = 0; i < 5; ++i) { - net.RunOp(DeviceType::GPU); + net.Setup(D); + for (int i = 0; i < 1; ++i) { + net.Run(); } const int64_t tot = static_cast(iters) * - (net.GetTensor("Input0")->size() + net.GetTensor("Input1")->size()); + (net.GetTensor("Input0")->size()); testing::BytesProcessed(tot * sizeof(T)); mace::testing::StartTiming(); while (iters--) { - net.RunOp(DeviceType::GPU); + net.Run(); } } } // namespace -#define MACE_BM_CROP_GPU_MACRO(N, H, W, C, AXIS, OFFSET, TYPE) \ - static void MACE_BM_CROP_GPU_##N##_##H##_##W##_##C##_##AXIS##_##OFFSET##\ - _##TYPE(int iters) { \ - std::vector shape0 = {N, H, W, C}; \ - std::vector shape1 = {N / 2, H / 2, W / 2, C / 2}; \ - OpenCLCropHelper(iters, shape0, shape1, AXIS, OFFSET); \ - } \ - MACE_BENCHMARK(MACE_BM_CROP_GPU_##N##_##H##_##W##_##C##_##AXIS##_##OFFSET\ - ##_##TYPE) - -MACE_BM_CROP_GPU_MACRO(4, 32, 32, 32, 2, 4, float); -MACE_BM_CROP_GPU_MACRO(8, 32, 32, 64, 1, 0, float); -MACE_BM_CROP_GPU_MACRO(8, 32, 32, 128, 0, 0, float); -MACE_BM_CROP_GPU_MACRO(8, 32, 32, 256, 2, 4, float); +#define MACE_BM_CROP_MACRO(N, H, W, C, AXIS, OFFSET, DEVICE, TYPE) \ + static void MACE_BM_CROP_##N##_##H##_##W##_##C##_##AXIS##_##OFFSET## \ + _##DEVICE##_##TYPE(int iters) { \ + std::vector shape0 = {N, H, W, C}; \ + std::vector shape1 = {N / 2, H / 2, W / 2, C / 2}; \ + CropHelper(iters, shape0, shape1, AXIS, OFFSET); \ + } \ + MACE_BENCHMARK(MACE_BM_CROP_##N##_##H##_##W##_##C##_##AXIS##_##OFFSET\ + ##_##DEVICE##_##TYPE) + +#define MACE_BM_CROP(N, H, W, C, AXIS, OFFSET) \ + MACE_BM_CROP_MACRO(N, H, W, C, AXIS, OFFSET, CPU, float); \ + MACE_BM_CROP_MACRO(N, H, W, C, AXIS, OFFSET, GPU, float); \ + MACE_BM_CROP_MACRO(N, H, W, C, AXIS, OFFSET, GPU, half); + +MACE_BM_CROP(4, 32, 32, 32, 2, 4); +MACE_BM_CROP(8, 32, 32, 64, 1, 0); +MACE_BM_CROP(8, 32, 32, 128, 0, 0); +MACE_BM_CROP(8, 32, 32, 256, 2, 4); -MACE_BM_CROP_GPU_MACRO(4, 32, 32, 32, 2, 4, half); -MACE_BM_CROP_GPU_MACRO(8, 32, 32, 64, 1, 0, half); -MACE_BM_CROP_GPU_MACRO(8, 32, 32, 128, 0, 0, half); -MACE_BM_CROP_GPU_MACRO(8, 32, 32, 256, 2, 4, half); } // namespace test } // namespace ops diff --git a/mace/ops/crop_test.cc b/mace/ops/crop_test.cc index 872d3154491e22a63ed6e98621a63476ea70ebb5..213b8ce89a58b5745c4e5685c6a825442b5826ce 100644 --- a/mace/ops/crop_test.cc +++ b/mace/ops/crop_test.cc @@ -26,7 +26,6 @@ void RunCrop(const std::vector &input_shape, const std::vector &input_data, const std::vector &input_shape2, const std::vector &offset, - const int axis, const std::vector &expected_shape, const std::vector &expected_data) { OpsTestNet net; @@ -39,7 +38,7 @@ void RunCrop(const std::vector &input_shape, .Input("Input1") .Output("Output") .AddIntsArg("offset", offset) - .AddIntArg("axis", axis) + .AddIntArg("has_data_format", 1) .Finalize(net.NewOperatorDef()); } else if (D == CPU) { net.TransformDataFormat("Input0", @@ -55,7 +54,7 @@ void RunCrop(const std::vector &input_shape, .Input("InputNCHW1") .Output("OutputNCHW") .AddIntsArg("offset", offset) - .AddIntArg("axis", axis) + .AddIntArg("has_data_format", 1) .Finalize(net.NewOperatorDef()); } @@ -113,7 +112,7 @@ TEST_F(CropTest, SimpleCPU) { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, - 4.0, 4.0, 4.0}, {1, 5, 5, 3}, {2, 2}, 2, + 4.0, 4.0, 4.0}, {1, 5, 5, 3}, {-1, 2, 2, -1}, {1, 5, 5, 3}, {1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, @@ -168,7 +167,7 @@ TEST_F(CropTest, SimpleGPU) { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, - 4.0, 4.0, 4.0}, {1, 5, 5, 3}, {2, 2}, 2, + 4.0, 4.0, 4.0}, {1, 5, 5, 3}, {-1, 2, 2, -1}, {1, 5, 5, 3}, {1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, diff --git a/mace/ops/opencl/image/crop.h b/mace/ops/opencl/image/crop.h index 3ffb4fba69a8a79f46d188fbe9ddd9a2540759f1..e390a6ca69a2712dc1959c75ece199255011a173 100644 --- a/mace/ops/opencl/image/crop.h +++ b/mace/ops/opencl/image/crop.h @@ -34,16 +34,14 @@ template class CropKernel : public OpenCLCropKernel { public: explicit CropKernel( - const int axis, const std::vector &offset) - : axis_(axis), offset_(offset) {} + : offset_(offset) {} MaceStatus Compute( OpContext *context, const std::vector &input_list, Tensor *output) override; private: - const int axis_; std::vector offset_; cl::Kernel kernel_; uint32_t kwg_size_; @@ -68,57 +66,14 @@ MaceStatus CropKernel::Compute( std::vector offsets(4, 0); std::vector output_shape(input0->shape()); - switch (axis_) { - case 0: - if (offset_.size() == 1) { - offsets[0] = offset_[0]; - offsets[1] = offset_[0]; - offsets[2] = offset_[0]; - offsets[3] = offset_[0]; - } else if (offset_.size() == 4) { - offsets[0] = offset_[0]; - offsets[1] = offset_[2]; - offsets[2] = offset_[3]; - offsets[3] = offset_[1]; - } - for (int i = 0; i < 4; ++i) { - output_shape[i] = input1->dim(i); - } - break; - case 1: - if (offset_.size() == 1) { - offsets[1] = offset_[0]; - offsets[2] = offset_[0]; - offsets[3] = offset_[0]; - } else if (offset_.size() == 3) { - offsets[1] = offset_[1]; - offsets[2] = offset_[2]; - offsets[3] = offset_[0]; - } - for (int i = 1; i < 4; ++i) { - output_shape[i] = input1->dim(i); - } - break; - case 2: - if (offset_.size() == 1) { - offsets[1] = offset_[0]; - offsets[2] = offset_[0]; - } else if (offset_.size() == 2) { - offsets[1] = offset_[0]; - offsets[2] = offset_[1]; - } - output_shape[1] = input1->dim(1); - output_shape[2] = input1->dim(2); - break; - case 3: - if (offset_.size() == 1) { - offsets[2] = offset_[0]; - } - output_shape[2] = input1->dim(2); - break; - default: - MACE_CHECK(axis_ >= 0 && axis_ < 4, "axis is out of boundary."); - break; + for (index_t i = 0; i < in0_dims; ++i) { + if (offset_[i] >= 0) { + output_shape[i] = input1->dim(i); + offsets[i] = offset_[i]; + MACE_CHECK(input0->dim(i) - offset_[i] >= input1->dim(i)) + << "the crop for dimension " << i << " is out of bound with size " + << input1->dim(i) << " and offset " << offsets[i]; + } } MACE_CHECK(offsets[3] % 4 == 0, "MACE opencl only supports cropping channel" diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py index cdabaea4c2c88848f2f451f482b20bc9d3ba8295..c5b6176824d28dcf67a4dd68defdebdfecafcbed 100644 --- a/mace/python/tools/converter_tool/caffe_converter.py +++ b/mace/python/tools/converter_tool/caffe_converter.py @@ -552,18 +552,20 @@ class CaffeConverter(base_converter.ConverterInterface): param = caffe_op.layer.crop_param op.type = MaceOp.Crop.name - axis_arg = op.arg.add() - axis_arg.name = MaceKeyword.mace_axis_str - axis_arg.i = 2 - if param.HasField(MaceKeyword.mace_axis_str): - axis_arg.i = param.axis - axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i + axis = param.axis + axis = 4 + axis if axis < 0 else axis + offset_value = -1 * np.ones(4, dtype=np.int32) + offset_len = len(param.offset) + if offset_len == 1: + while axis < 4: + offset_value[axis] = param.offset[0] + axis += 1 + else: + offset_value[axis:] = param.offset + offset_arg = op.arg.add() offset_arg.name = MaceKeyword.mace_offset_str - if len(param.offset) > 0: - offset_arg.ints.extend(list(param.offset)) - else: - offset_arg.i = 0 + offset_arg.ints.extend(offset_value) def convert_concat(self, caffe_op): op = self.convert_general_op(caffe_op) diff --git a/mace/python/tools/converter_tool/shape_inference.py b/mace/python/tools/converter_tool/shape_inference.py index 2ca45425d37c27ee88106a15ab8637bae91f5bac..45254333915250c9366add9de94f626a3f6f5e65 100644 --- a/mace/python/tools/converter_tool/shape_inference.py +++ b/mace/python/tools/converter_tool/shape_inference.py @@ -224,7 +224,12 @@ class ShapeInference(object): def infer_shape_crop(self, op): mace_check(len(op.input) == 2, "crop layer needs two inputs") - output_shape = self._output_shape_cache[op.input[1]] + output_shape = self._output_shape_cache[op.input[0]] + input1_shape = self._output_shape_cache[op.input[1]] + offsets = ConverterUtil.get_arg(op, MaceKeyword.mace_offset_str).ints + for i in range(len(offsets)): + if offsets[i] >= 0: + output_shape[i] = input1_shape[i] self.add_output_shape(op, [output_shape]) def infer_shape_channel_shuffle(self, op): diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 62f482a6529a4bdc5afb053c0a8b2801cf29234e..d9fd30900978061237db03416b69e4fe8b80b316 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -1010,7 +1010,8 @@ class Transformer(base_converter.ConverterInterface): elif filter_format == DataFormat.OIHW: weight.dims[:] = weight.dims[:] + [1, 1] else: - mace_check("FC does not support filter format %s", + mace_check(False, + "FC does not support filter format %s" % filter_format.name) return False @@ -1082,6 +1083,16 @@ class Transformer(base_converter.ConverterInterface): new_axises.sort() arg.ints[:] = [] arg.ints.extend(new_axises) + elif op.type == MaceOp.Crop.name: + offset_arg = ConverterUtil.get_arg(op, + MaceKeyword.mace_offset_str) + mace_check(offset_arg and + ConverterUtil.data_format(op) == DataFormat.NCHW and + len(op.output_shape[0].dims) == 4, + "MACE only support crop with NCHW format") + print("Transpose crop args: %s(%s)" + % (op.name, op.type)) + self.transpose_shape(offset_arg.ints, [0, 2, 3, 1]) # transpose op output shape data_format = ConverterUtil.data_format(op) @@ -1145,7 +1156,7 @@ class Transformer(base_converter.ConverterInterface): elif filter_format == DataFormat.OIHW: transpose_order = [0, 2, 3, 1] else: - mace_check("Quantize model does not support conv " + mace_check(False, "Quantize model does not support conv " "filter format: %s" % filter_format.name) for op in net.op: