From e9d1103b37f9cb983f87449c8670ed55564c7a16 Mon Sep 17 00:00:00 2001 From: liutuo Date: Mon, 27 Aug 2018 16:47:56 +0800 Subject: [PATCH] fold shape_stridedslice_stack --- mace/kernels/deconv_2d.h | 78 +++++++++---------- mace/kernels/opencl/cl/deconv_2d.cl | 40 +++++----- mace/kernels/opencl/deconv_2d.cc | 64 +++++++++------ mace/ops/deconv_2d.h | 15 ++-- mace/ops/deconv_2d_test.cc | 24 +++--- mace/ops/shape.h | 16 +++- .../tools/converter_tool/base_converter.py | 1 - .../tools/converter_tool/caffe_converter.py | 4 - 8 files changed, 126 insertions(+), 116 deletions(-) diff --git a/mace/kernels/deconv_2d.h b/mace/kernels/deconv_2d.h index ad527a84..9450104d 100644 --- a/mace/kernels/deconv_2d.h +++ b/mace/kernels/deconv_2d.h @@ -90,20 +90,18 @@ void Deconv2dNCHW(const T *input, } // namespace deconv struct Deconv2dFunctorBase { - Deconv2dFunctorBase(const int *strides, + Deconv2dFunctorBase(const std::vector &strides, const Padding &padding_type, const std::vector &paddings, const std::vector &output_shape, const ActivationType activation, - const float relux_max_limit, - const bool from_caffe) + const float relux_max_limit) : strides_(strides), padding_type_(padding_type), paddings_(paddings), output_shape_(output_shape), activation_(activation), - relux_max_limit_(relux_max_limit), - from_caffe_(from_caffe) {} + relux_max_limit_(relux_max_limit) {} static void CalcDeconvOutputSize( const index_t *input_shape, // NHWC @@ -202,31 +200,28 @@ struct Deconv2dFunctorBase { padding_size[1] = std::max(0, p_w); } - const int *strides_; // [stride_h, stride_w] + std::vector strides_; // [stride_h, stride_w] const Padding padding_type_; std::vector paddings_; std::vector output_shape_; const ActivationType activation_; const float relux_max_limit_; - const bool from_caffe_; }; template struct Deconv2dFunctor : Deconv2dFunctorBase { - Deconv2dFunctor(const int *strides, + Deconv2dFunctor(const std::vector &strides, const Padding &padding_type, const std::vector &paddings, const std::vector &output_shape, const ActivationType activation, - const float relux_max_limit, - const bool from_caffe) + const float relux_max_limit) : Deconv2dFunctorBase(strides, padding_type, paddings, output_shape, activation, - relux_max_limit, - from_caffe) {} + relux_max_limit) {} MaceStatus operator()(const Tensor *input, // NCHW const Tensor *filter, // OIHW @@ -239,13 +234,12 @@ struct Deconv2dFunctor : Deconv2dFunctorBase { MACE_CHECK_NOTNULL(filter); MACE_CHECK_NOTNULL(output); - if (!from_caffe_) { // tensorflow - std::vector output_shape(4); + std::vector paddings(2); + std::vector output_shape(4); + if (paddings_.empty()) { // tensorflow + paddings = std::vector(2, 0); if (output_shape_.size() == 4) { - output_shape[0] = output_shape_[0]; - output_shape[1] = output_shape_[3]; - output_shape[2] = output_shape_[1]; - output_shape[3] = output_shape_[2]; + output_shape = output_shape_; } else { MACE_CHECK_NOTNULL(output_shape_tensor); MACE_CHECK(output_shape_tensor->size() == 4); @@ -255,36 +249,38 @@ struct Deconv2dFunctor : Deconv2dFunctorBase { output_shape = std::vector(output_shape_data, output_shape_data + 4); } - paddings_.clear(); - paddings_ = std::vector(2, 0); + const index_t t = output_shape[1]; + output_shape[1] = output_shape[3]; + output_shape[3] = output_shape[2]; + output_shape[2] = t; + CalcDeconvPaddingAndInputSize( input->shape().data(), filter->shape().data(), - strides_, padding_type_, + strides_.data(), padding_type_, output_shape.data(), - paddings_.data(), true); - MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + paddings.data(), true); } else { // caffe - output_shape_.clear(); - output_shape_ = std::vector(4, 0); + paddings = paddings_; + output_shape = std::vector(4, 0); CalcDeconvOutputSize(input->shape().data(), filter->shape().data(), - strides_, - output_shape_.data(), - paddings_.data(), true); - MACE_RETURN_IF_ERROR(output->Resize(output_shape_)); + strides_.data(), + output_shape.data(), + paddings.data(), true); } + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); index_t kernel_h = filter->dim(2); index_t kernel_w = filter->dim(3); const index_t *in_shape = input->shape().data(); - const index_t *out_shape = output->shape().data(); const index_t kernel_hw[2] = {kernel_h, kernel_w}; - MACE_CHECK(filter->dim(0) == out_shape[1], filter->dim(0), " != ", - out_shape[1]); + MACE_CHECK(filter->dim(0) == output_shape[1], filter->dim(0), " != ", + output_shape[1]); MACE_CHECK(filter->dim(1) == in_shape[1], filter->dim(1), " != ", in_shape[1]); - MACE_CHECK(in_shape[0] == out_shape[0], "Input/Output batch size mismatch"); + MACE_CHECK(in_shape[0] == output_shape[0], + "Input/Output batch size mismatch"); Tensor::MappingGuard input_mapper(input); Tensor::MappingGuard filter_mapper(filter); Tensor::MappingGuard bias_mapper(bias); @@ -294,15 +290,15 @@ struct Deconv2dFunctor : Deconv2dFunctorBase { auto bias_data = bias == nullptr ? nullptr : bias->data(); auto output_data = output->mutable_data(); int padding[2]; - padding[0] = (paddings_[0] + 1) >> 1; - padding[1] = (paddings_[1] + 1) >> 1; + padding[0] = (paddings[0] + 1) >> 1; + padding[1] = (paddings[1] + 1) >> 1; deconv::Deconv2dNCHW(input_data, filter_data, bias_data, in_shape, - out_shape, + output_shape.data(), kernel_hw, - strides_, + strides_.data(), padding, output_data); @@ -319,20 +315,18 @@ struct Deconv2dFunctor : Deconv2dFunctorBase { #ifdef MACE_ENABLE_OPENCL template struct Deconv2dFunctor : Deconv2dFunctorBase { - Deconv2dFunctor(const int *strides, + Deconv2dFunctor(const std::vector &strides, const Padding &padding_type, const std::vector &paddings, const std::vector &output_shape, const ActivationType activation, - const float relux_max_limit, - const bool from_caffe) + const float relux_max_limit) : Deconv2dFunctorBase(strides, padding_type, paddings, output_shape, activation, - relux_max_limit, - from_caffe) {} + relux_max_limit) {} MaceStatus operator()(const Tensor *input, const Tensor *filter, diff --git a/mace/kernels/opencl/cl/deconv_2d.cl b/mace/kernels/opencl/cl/deconv_2d.cl index f64dc7bb..d7728f3e 100644 --- a/mace/kernels/opencl/cl/deconv_2d.cl +++ b/mace/kernels/opencl/cl/deconv_2d.cl @@ -15,8 +15,10 @@ __kernel void deconv_2d(KERNEL_ERROR_PARAMS __private const int out_height, __private const int out_width, __private const int out_channel, - __private const int stride, - __private const float stride_r, + __private const int stride_h, + __private const int stride_w, + __private const float stride_h_r, + __private const float stride_w_r, __private const int align_h, __private const int align_w, __private const int padding_h, @@ -53,18 +55,18 @@ __kernel void deconv_2d(KERNEL_ERROR_PARAMS DATA_TYPE4 out4 = 0; #endif - const int n_stride = mad(w_id, stride_r, 0); - const int mod_stride = w_id - mul24(n_stride, stride); - const int w = mad24(mul24(n_stride, 5), stride, mod_stride); + const int n_stride = mad(w_id, stride_w_r, 0); + const int mod_stride = w_id - mul24(n_stride, stride_w); + const int w = mad24(mul24(n_stride, 5), stride_w, mod_stride); const int b = hb / out_height; const int h = hb - mul24(b, out_height); if (w < out_width) { - int start_x = floor((float) (w + align_w) * stride_r); - int start_y = (h + align_h) * stride_r; + int start_x = floor((float) (w + align_w) * stride_w_r); + int start_y = (h + align_h) * stride_h_r; start_y = max(0, start_y); - int f_start_x = mad24(start_x, stride, padding_w) - w; - int f_start_y = mad24(start_y, stride, padding_h) - h; + int f_start_x = mad24(start_x, stride_w, padding_w) - w; + int f_start_y = mad24(start_y, stride_h, padding_h) - h; f_start_x = kernel_w - 1 - f_start_x; f_start_y = kernel_h - 1 - f_start_y; @@ -79,10 +81,10 @@ __kernel void deconv_2d(KERNEL_ERROR_PARAMS f_pos_x1 = f_pos_x0 + 1; f_pos_x2 = f_pos_x0 + 2; f_pos_x3 = f_pos_x0 + 3; - for (int f_y = f_start_y, idx_h = start_y ; f_y >= 0; f_y -= stride, ++idx_h) { + for (int f_y = f_start_y, idx_h = start_y ; f_y >= 0; f_y -= stride_h, ++idx_h) { index_y = mad24(b, in_height, idx_h); in_pos.y = select(index_y, -1, idx_h < 0 || idx_h >= in_height); - for (int f_x = f_start_x, idx_w = start_x; f_x >= 0; f_x -= stride, ++idx_w) { + for (int f_x = f_start_x, idx_w = start_x; f_x >= 0; f_x -= stride_w, ++idx_w) { f_pos_y = mad24(f_y, kernel_w, f_x); f_pos_y = mad24(c, kernel_size, f_pos_y); weight0 = READ_IMAGET(weights, SAMPLER, (int2)(f_pos_x0, f_pos_y)); @@ -141,24 +143,24 @@ __kernel void deconv_2d(KERNEL_ERROR_PARAMS out_pos.x = mad24(c, out_width, ow); WRITE_IMAGET(output, out_pos, out0); - ow += stride; + ow += stride_w; if (ow >= out_width) return; - out_pos.x += stride; + out_pos.x += stride_w; WRITE_IMAGET(output, out_pos, out1); - ow += stride; + ow += stride_w; if (ow >= out_width) return; - out_pos.x += stride; + out_pos.x += stride_w; WRITE_IMAGET(output, out_pos, out2); - ow += stride; + ow += stride_w; if (ow >= out_width) return; - out_pos.x += stride; + out_pos.x += stride_w; WRITE_IMAGET(output, out_pos, out3); - ow += stride; + ow += stride_w; if (ow >= out_width) return; - out_pos.x += stride; + out_pos.x += stride_w; WRITE_IMAGET(output, out_pos, out4); } } \ No newline at end of file diff --git a/mace/kernels/opencl/deconv_2d.cc b/mace/kernels/opencl/deconv_2d.cc index 770d64ef..cba8cbce 100644 --- a/mace/kernels/opencl/deconv_2d.cc +++ b/mace/kernels/opencl/deconv_2d.cc @@ -24,7 +24,7 @@ MaceStatus Deconv2dOpencl(cl::Kernel *kernel, const Tensor *input, const Tensor *filter, const Tensor *bias, - const int stride, + const int *strides, const int *paddings, const ActivationType activation, const float relux_max_limit, @@ -42,17 +42,20 @@ MaceStatus Deconv2dOpencl(cl::Kernel *kernel, const index_t channel_blocks = RoundUpDiv4(channels); const index_t input_channel_blocks = RoundUpDiv4(input_channels); - MACE_CHECK(stride > 0, "stride should > 0."); + const int stride_h = strides[0]; + const int stride_w = strides[1]; + MACE_CHECK(stride_w > 0 && stride_h > 0, "strides should be > 0."); #define MACE_WIDTH_BLK 5 - const index_t n_strides = (width + stride - 1) / stride; + const index_t n_strides = (width + stride_w - 1) / stride_w; const index_t width_blocks = - ((n_strides + MACE_WIDTH_BLK - 1) / MACE_WIDTH_BLK) * stride; - const float stride_r = 1.f / static_cast(stride); + ((n_strides + MACE_WIDTH_BLK - 1) / MACE_WIDTH_BLK) * stride_w; + const float stride_h_r = 1.f / static_cast(stride_h); + const float stride_w_r = 1.f / static_cast(stride_w); const int padding_h = (paddings[0] + 1) >> 1; - const int padding_w = (paddings[0] + 1) >> 1; + const int padding_w = (paddings[1] + 1) >> 1; - const int align_h = stride - 1 - padding_h; - const int align_w = stride - 1 - padding_w; + const int align_h = stride_h - 1 - padding_h; + const int align_w = stride_w - 1 - padding_w; const int kernel_size = filter->dim(2) * filter->dim(3); auto runtime = OpenCLRuntime::Global(); @@ -113,8 +116,10 @@ MaceStatus Deconv2dOpencl(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(height)); kernel->setArg(idx++, static_cast(width)); kernel->setArg(idx++, static_cast(channels)); - kernel->setArg(idx++, static_cast(stride)); - kernel->setArg(idx++, stride_r); + kernel->setArg(idx++, static_cast(stride_h)); + kernel->setArg(idx++, static_cast(stride_w)); + kernel->setArg(idx++, stride_h_r); + kernel->setArg(idx++, stride_w_r); kernel->setArg(idx++, static_cast(align_h)); kernel->setArg(idx++, static_cast(align_w)); kernel->setArg(idx++, static_cast(padding_h)); @@ -152,34 +157,43 @@ MaceStatus Deconv2dFunctor::operator()( MACE_CHECK_NOTNULL(input); MACE_CHECK_NOTNULL(filter); MACE_CHECK_NOTNULL(output); - if (!from_caffe_) { + std::vector paddings(2); + std::vector output_shape(4); + if (paddings_.empty()) { + paddings = std::vector(2, 0); if (output_shape_.size() != 4) { MACE_CHECK_NOTNULL(output_shape_tensor); MACE_CHECK(output_shape_tensor->size() == 4); Tensor::MappingGuard output_shape_mapper(output_shape_tensor); auto output_shape_data = output_shape_tensor->data(); - output_shape_ = + output_shape = std::vector(output_shape_data, output_shape_data + 4); + } else { + output_shape = output_shape_; } - paddings_.clear(); - paddings_ = std::vector(2, 0); - CalcDeconvPaddingAndInputSize(input->shape().data(), filter->shape().data(), - strides_, padding_type_, output_shape_.data(), - paddings_.data()); + CalcDeconvPaddingAndInputSize(input->shape().data(), + filter->shape().data(), + strides_.data(), + padding_type_, + output_shape.data(), + paddings.data()); } else { - output_shape_.clear(); - output_shape_ = std::vector(4, 0); - CalcDeconvOutputSize(input->shape().data(), filter->shape().data(), - strides_, output_shape_.data(), paddings_.data()); + paddings = paddings_; + output_shape = std::vector(4, 0); + CalcDeconvOutputSize(input->shape().data(), + filter->shape().data(), + strides_.data(), + output_shape.data(), + paddings.data()); } std::vector output_image_shape; - CalImage2DShape(output_shape_, BufferType::IN_OUT_CHANNEL, + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &output_image_shape); - MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape_, output_image_shape)); + MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape)); - return Deconv2dOpencl(&kernel_, input, filter, bias, strides_[0], - paddings_.data(), activation_, relux_max_limit_, + return Deconv2dOpencl(&kernel_, input, filter, bias, strides_.data(), + paddings.data(), activation_, relux_max_limit_, DataTypeToEnum::value, &input_shape_, output, future, &kwg_size_, &kernel_error_); } diff --git a/mace/ops/deconv_2d.h b/mace/ops/deconv_2d.h index fae87ce9..188b8ba0 100644 --- a/mace/ops/deconv_2d.h +++ b/mace/ops/deconv_2d.h @@ -19,23 +19,22 @@ #include "mace/core/operator.h" #include "mace/kernels/deconv_2d.h" -#include "mace/ops/conv_pool_2d_base.h" namespace mace { namespace ops { template -class Deconv2dOp : public ConvPool2dOpBase { +class Deconv2dOp : public Operator { public: Deconv2dOp(const OperatorDef &op_def, Workspace *ws) - : ConvPool2dOpBase(op_def, ws), - functor_(this->strides_.data(), - this->padding_type_, - this->paddings_, + : Operator(op_def, ws), + functor_(OperatorBase::GetRepeatedArgs("strides"), + static_cast(OperatorBase::GetOptionalArg( + "padding", static_cast(SAME))), + OperatorBase::GetRepeatedArgs("padding_values"), OperatorBase::GetRepeatedArgs("output_shape"), kernels::ActivationType::NOOP, - 0.0f, - OperatorBase::GetOptionalArg("from_caffe", false)) {} + 0.0f) {} MaceStatus Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); diff --git a/mace/ops/deconv_2d_test.cc b/mace/ops/deconv_2d_test.cc index 9b13d5ca..954d6bf4 100644 --- a/mace/ops/deconv_2d_test.cc +++ b/mace/ops/deconv_2d_test.cc @@ -41,7 +41,6 @@ void RunTestSimple(const std::vector &input_shape, net.AddInputFromArray("Input", input_shape, input_data); net.AddInputFromArray("Filter", filter_shape, filter_data); net.TransformDataFormat("Filter", HWOI, "FilterOIHW", OIHW); - bool from_caffe = output_shape.size() != 4; if (D == DeviceType::GPU) { BufferToImage(&net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); @@ -55,7 +54,6 @@ void RunTestSimple(const std::vector &input_shape, .AddIntArg("padding", padding) .AddIntsArg("padding_values", padding_size) .AddIntsArg("output_shape", output_shape) - .AddIntArg("from_caffe", from_caffe) .Finalize(net.NewOperatorDef()); net.RunOp(D); @@ -74,7 +72,6 @@ void RunTestSimple(const std::vector &input_shape, .AddIntArg("padding", padding) .AddIntsArg("padding_values", padding_size) .AddIntsArg("output_shape", output_shape) - .AddIntArg("from_caffe", from_caffe) .Finalize(net.NewOperatorDef()); // Run net.RunOp(D); @@ -89,7 +86,7 @@ void RunTestSimple(const std::vector &input_shape, template void TestNHWCSimple3x3SAME_S1() { RunTestSimple({1, 3, 3, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1}, 1, Padding::SAME, - {0, 0}, {1, 3, 3, 3}, {3, 3, 3, 1}, + {}, {1, 3, 3, 3}, {3, 3, 3, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, {1, 3, 3, 3}, {4, 4, 4, 6, 6, 6, 4, 4, 4, 6, 6, 6, 9, 9, @@ -101,7 +98,7 @@ void TestNHWCSimple3x3SAME_S1() { {1, 3, 3, 3}, {4, 4, 4, 6, 6, 6, 4, 4, 4, 6, 6, 6, 9, 9, 9, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 4, 4}); RunTestSimple({1, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, 1, Padding::SAME, - {0, 0}, {1, 3, 3, 3}, {3, 3, 3, 1}, + {}, {1, 3, 3, 3}, {3, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}, {1, 3, 3, 3}, {54, 66, 78, 126, 147, 168, 130, 146, 162, @@ -119,7 +116,7 @@ void TestNHWCSimple3x3SAME_S1() { template void TestNHWCSimple3x3SAME_S2() { RunTestSimple( - {1, 3, 3, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1}, 2, Padding::SAME, {0, 0}, + {1, 3, 3, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1}, 2, Padding::SAME, {}, {1, 6, 6, 3}, {3, 3, 3, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, {1, 6, 6, 3}, @@ -137,7 +134,7 @@ void TestNHWCSimple3x3SAME_S2() { 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 4, 4, 4, 2, 2, 2, 4, 4, 4, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1}); RunTestSimple( - {1, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, 2, Padding::SAME, {0, 0}, + {1, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, 2, Padding::SAME, {}, {1, 6, 6, 3}, {3, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}, @@ -167,7 +164,7 @@ template void TestNHWCSimple3x3SAME_S2_1() { RunTestSimple( {1, 3, 3, 1}, {12, 18, 12, 18, 27, 18, 12, 18, 12}, 2, Padding::SAME, - {0, 0}, {1, 5, 5, 3}, {3, 3, 3, 1}, + {}, {1, 5, 5, 3}, {3, 3, 3, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, {1, 5, 5, 3}, @@ -181,7 +178,7 @@ void TestNHWCSimple3x3SAME_S2_1() { template void TestNHWCSimple3x3VALID_S2() { RunTestSimple( - {1, 3, 3, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1}, 2, Padding::VALID, {0, 0}, + {1, 3, 3, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1}, 2, Padding::VALID, {}, {1, 7, 7, 3}, {3, 3, 3, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, {1, 7, 7, 3}, @@ -197,7 +194,7 @@ void TestNHWCSimple3x3VALID_S2() { template void TestNHWCSimple3x3VALID_S1() { RunTestSimple( - {1, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, 1, Padding::VALID, {0, 0}, + {1, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, 1, Padding::VALID, {}, {1, 5, 5, 3}, {3, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}, @@ -212,7 +209,7 @@ void TestNHWCSimple3x3VALID_S1() { template void TestNHWCSimple2x2SAME() { - RunTestSimple({1, 2, 2, 1}, {1, 1, 1, 1}, 1, Padding::SAME, {0, 0}, + RunTestSimple({1, 2, 2, 1}, {1, 1, 1, 1}, 1, Padding::SAME, {}, {1, 2, 2, 1}, {3, 3, 1, 1}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, {1, 2, 2, 1}, {4.f, 4.f, 4.f, 4.f}); @@ -221,7 +218,7 @@ void TestNHWCSimple2x2SAME() { template void TestNHWCSimple2x2VALID() { RunTestSimple( - {1, 2, 2, 1}, {1, 1, 1, 1}, 2, Padding::VALID, {0, 0}, {1, 5, 5, 1}, + {1, 2, 2, 1}, {1, 1, 1, 1}, 2, Padding::VALID, {}, {1, 5, 5, 1}, {3, 3, 1, 1}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, {1, 5, 5, 1}, {1.f, 1.f, 2.f, 1.f, 1.f, 1.f, 1.f, 2.f, 1.f, 1.f, 2.f, 2.f, 4.f, @@ -333,7 +330,6 @@ void TestComplexDeconvNxNS12(const int batch, paddings.push_back(padding); paddings.push_back(padding); } - bool from_caffe = output_shape.size() != 4; // Construct graph OpDefBuilder("Deconv2D", "Deconv2dTest") .Input("InputNCHW") @@ -344,7 +340,6 @@ void TestComplexDeconvNxNS12(const int batch, .AddIntArg("padding", type) .AddIntsArg("padding_values", paddings) .AddIntsArg("output_shape", output_shape) - .AddIntArg("from_caffe", from_caffe) .AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); @@ -375,7 +370,6 @@ void TestComplexDeconvNxNS12(const int batch, .AddIntArg("padding", type) .AddIntsArg("padding_values", paddings) .AddIntsArg("output_shape", output_shape) - .AddIntArg("from_caffe", from_caffe) .AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); // Run on device diff --git a/mace/ops/shape.h b/mace/ops/shape.h index 18d73e61..98f139e4 100644 --- a/mace/ops/shape.h +++ b/mace/ops/shape.h @@ -39,8 +39,20 @@ class ShapeOp : public Operator { Tensor::MappingGuard output_guard(output); int32_t *output_data = output->mutable_data(); - for (index_t i = 0; i < input->dim_size(); ++i) { - output_data[i] = input->dim(i); + const int data_format = + OperatorBase::GetOptionalArg("data_format", 0); + if (input->dim_size() == 4 && + D == DeviceType::CPU && + data_format == DataFormat::NCHW) { + // transpose NCHW to NHWC for cpu runtime + output_data[0] = static_cast(input->dim(0)); + output_data[1] = static_cast(input->dim(2)); + output_data[2] = static_cast(input->dim(3)); + output_data[3] = static_cast(input->dim(1)); + } else { + for (unsigned int i = 0; i < input->dim_size(); ++i) { + output_data[i] = static_cast(input->dim(i)); + } } SetFutureDefaultWaitFn(future); diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 7f873dda..37f753fd 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -167,7 +167,6 @@ class MaceKeyword(object): mace_transpose_b_str = 'transpose_b' mace_op_data_type_str = 'T' mace_offset_str = 'offset' - mace_from_caffe_str = 'from_caffe' mace_opencl_max_image_size = "opencl_max_image_size" mace_seperate_buffer_str = 'seperate_buffer' mace_scalar_input_index_str = 'scalar_input_index' diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py index d9cc35aa..eea1a235 100644 --- a/mace/python/tools/converter_tool/caffe_converter.py +++ b/mace/python/tools/converter_tool/caffe_converter.py @@ -414,10 +414,6 @@ class CaffeConverter(base_converter.ConverterInterface): op.type = MaceOp.Deconv2D.name - from_caffe_arg = op.arg.add() - from_caffe_arg.name = MaceKeyword.mace_from_caffe_str - from_caffe_arg.i = 1 - self.add_stride_pad_kernel_arg(param, op) # dilation is specific for convolution in caffe dilations = [1, 1] -- GitLab