diff --git a/mace/examples/cli/README.md b/mace/examples/cli/README.md
index ba70a7e61dcaef7a550a68262cc872fc333142f5..50e64f950e80afa1cb72199df3f68e0c0e7b518b 100644
--- a/mace/examples/cli/README.md
+++ b/mace/examples/cli/README.md
@@ -14,7 +14,7 @@ python tools/converter.py run --config=/path/to/your/model_deployment_file --exa
 * Validate result
 
 ```
-python tools/converter.py run --config=/path/to/your/model_deployment_file --example --example
+python tools/converter.py run --config=/path/to/your/model_deployment_file --example --validate
 ```
 
 * Check the logs
diff --git a/mace/kernels/opencl/winograd_transform.cc b/mace/kernels/opencl/winograd_transform.cc
index dd6d16d6351810c544263aadf5e0a7abbe24fcb3..74d8776fa089168c87ea7b1751244d3151e28492 100644
--- a/mace/kernels/opencl/winograd_transform.cc
+++ b/mace/kernels/opencl/winograd_transform.cc
@@ -129,12 +129,14 @@ MaceStatus WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
 
 template <typename T>
 MaceStatus WinogradInverseTransformFunctor<DeviceType::GPU, T>::operator()(
-    const Tensor *input_tensor,
-    const Tensor *bias,
+    const std::vector<const Tensor *> &inputs,
     Tensor *output_tensor,
     StatsFuture *future) {
   auto runtime = OpenCLRuntime::Global();
 
+  const Tensor *input_tensor = inputs[0];
+  const Tensor *bias = inputs.size() == 3 ? inputs[2] : nullptr;
+
   if (kernel_.get() == nullptr) {
     std::string obfuscated_kernel_name;
     std::set<std::string> built_options;
@@ -191,18 +193,23 @@ MaceStatus WinogradInverseTransformFunctor<DeviceType::GPU, T>::operator()(
         static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
   }
 
+  Tensor::MappingGuard output_shape_guard(inputs[1]);
+  const int32_t *output_shape_data = inputs[1]->data<int32_t>();
+  const index_t batch = output_shape_data[0];
+  const index_t height = output_shape_data[1];
+  const index_t width = output_shape_data[2];
   const uint32_t gws[2] = {
       static_cast<uint32_t>(input_tensor->dim(2)),
       static_cast<uint32_t>(RoundUpDiv4(input_tensor->dim(1)))};
   if (!IsVecEqual(input_shape_, input_tensor->shape())) {
-    std::vector<index_t> output_shape = {batch_, height_, width_,
+    std::vector<index_t> output_shape = {batch, height, width,
                                          input_tensor->dim(1)};
     std::vector<size_t> image_shape;
     CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
     MACE_RETURN_IF_ERROR(output_tensor->ResizeImage(output_shape, image_shape));
 
-    const index_t round_h = (height_ + wino_blk_size_ - 1) / wino_blk_size_;
-    const index_t round_w = (width_ + wino_blk_size_ - 1) / wino_blk_size_;
+    const index_t round_h = (height + wino_blk_size_ - 1) / wino_blk_size_;
+    const index_t round_w = (width + wino_blk_size_ - 1) / wino_blk_size_;
     const float round_hw_r = 1.f / static_cast<float>(round_h * round_w);
     const float round_w_r = 1.f / static_cast<float>(round_w);
diff --git a/mace/kernels/winograd_transform.h b/mace/kernels/winograd_transform.h
index 49ecb8492e8e1ec896863f14785a43aea61ebc3f..c7d6fc1aaf681d6a02e33dfc374da4dadcf6e6fb 100644
--- a/mace/kernels/winograd_transform.h
+++ b/mace/kernels/winograd_transform.h
@@ -86,22 +86,13 @@ struct WinogradTransformFunctor<DeviceType::GPU, T>
 #endif  // MACE_ENABLE_OPENCL
 
 struct WinogradInverseTransformFunctorBase {
-  WinogradInverseTransformFunctorBase(const int batch,
-                                      const int height,
-                                      const int width,
-                                      const ActivationType activation,
+  WinogradInverseTransformFunctorBase(const ActivationType activation,
                                       const float relux_max_limit,
                                       const int block_size)
-      : batch_(batch),
-        height_(height),
-        width_(width),
-        wino_blk_size_(block_size),
+      : wino_blk_size_(block_size),
         activation_(activation),
         relux_max_limit_(relux_max_limit) {}
 
-  const int batch_;
-  const int height_;
-  const int width_;
   const int wino_blk_size_;
   const ActivationType activation_;
   const float relux_max_limit_;
@@ -109,21 +100,16 @@ struct WinogradInverseTransformFunctorBase {
 
 template <DeviceType D, typename T>
 struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
-  WinogradInverseTransformFunctor(const int batch,
-                                  const int height,
-                                  const int width,
-                                  const ActivationType activation,
+  WinogradInverseTransformFunctor(const ActivationType activation,
                                   const float relux_max_limit,
                                   const int block_size)
       : WinogradInverseTransformFunctorBase(
-            batch, height, width, activation, relux_max_limit, block_size) {}
+            activation, relux_max_limit, block_size) {}
 
-  MaceStatus operator()(const Tensor *input,
-                        const Tensor *bias,
+  MaceStatus operator()(const std::vector<const Tensor *> &inputs,
                         Tensor *output,
                         StatsFuture *future) {
-    MACE_UNUSED(input);
-    MACE_UNUSED(bias);
+    MACE_UNUSED(inputs);
     MACE_UNUSED(output);
     MACE_UNUSED(future);
     MACE_NOT_IMPLEMENTED;
@@ -135,17 +121,13 @@ struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
 
 template <typename T>
 struct WinogradInverseTransformFunctor<DeviceType::GPU, T>
     : WinogradInverseTransformFunctorBase {
-  WinogradInverseTransformFunctor(const int batch,
-                                  const int height,
-                                  const int width,
-                                  const ActivationType activation,
+  WinogradInverseTransformFunctor(const ActivationType activation,
                                   const float relux_max_limit,
                                   const int block_size)
       : WinogradInverseTransformFunctorBase(
-            batch, height, width, activation, relux_max_limit, block_size) {}
+            activation, relux_max_limit, block_size) {}
 
-  MaceStatus operator()(const Tensor *input,
-                        const Tensor *bias,
+  MaceStatus operator()(const std::vector<const Tensor *> &inputs,
                         Tensor *output,
                         StatsFuture *future);
diff --git a/mace/ops/infer_conv2d_shape.cc b/mace/ops/infer_conv2d_shape.cc
new file mode 100644
index 0000000000000000000000000000000000000000..26aec062354d6f65d699f02f4ef976fba118fa97
--- /dev/null
+++ b/mace/ops/infer_conv2d_shape.cc
@@ -0,0 +1,46 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/infer_conv2d_shape.h"
+
+namespace mace {
+namespace ops {
+
+void Register_InferConv2dShape(OperatorRegistryBase *op_registry) {
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("InferConv2dShape")
+                                          .Device(DeviceType::CPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         InferConv2dShapeOp<DeviceType::CPU, float>);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("InferConv2dShape")
+                                          .Device(DeviceType::CPU)
+                                          .TypeConstraint<half>("T")
+                                          .Build(),
+                         InferConv2dShapeOp<DeviceType::CPU, half>);
+#ifdef MACE_ENABLE_OPENCL
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("InferConv2dShape")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         InferConv2dShapeOp<DeviceType::GPU, float>);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("InferConv2dShape")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<half>("T")
+                                          .Build(),
+                         InferConv2dShapeOp<DeviceType::GPU, half>);
+#endif
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/infer_conv2d_shape.h b/mace/ops/infer_conv2d_shape.h
new file mode 100644
index 0000000000000000000000000000000000000000..bc6163c170524800a5e0bbe5d83b7c419aeb123b
--- /dev/null
+++ b/mace/ops/infer_conv2d_shape.h
@@ -0,0 +1,113 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_INFER_CONV2D_SHAPE_H_
+#define MACE_OPS_INFER_CONV2D_SHAPE_H_
+
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/conv_pool_2d_util.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class InferConv2dShapeOp : public Operator<D, T> {
+ public:
+  InferConv2dShapeOp(const OperatorDef &op_def, Workspace *ws)
+      : Operator<D, T>(op_def, ws) {}
+
+  MaceStatus Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    Tensor *output = this->Output(OUTPUT);
+    MACE_CHECK(input->dim_size() == 4);
+    output->Resize({input->dim_size()});
+    Tensor::MappingGuard output_guard(output);
+    int32_t *output_data = output->mutable_data<int32_t>();
+
+    const int32_t data_format =
+        OperatorBase::GetOptionalArg<int>("data_format", 0);
+    const bool isNCHW = data_format == 1;
+
+    const Padding padding_type =
+        static_cast<Padding>(OperatorBase::GetOptionalArg<int>(
+            "padding", static_cast<int>(SAME)));
+    const std::vector<int> paddings =
+        OperatorBase::GetRepeatedArgs<int>("padding_values");
+    const std::vector<int> kernels =
+        OperatorBase::GetRepeatedArgs<int>("kernels");
+    const std::vector<int> strides =
+        OperatorBase::GetRepeatedArgs<int>("strides", {1, 1});
+
+    const int32_t out_batch = static_cast<int32_t>(input->dim(0));
+    const int32_t out_channel = static_cast<int32_t>(kernels[0]);
+
+    int32_t in_h = 0, in_w = 0, in_c = 0;
+    if (isNCHW) {  // NCHW
+      in_c = static_cast<int32_t>(input->dim(1));
+      in_h = static_cast<int32_t>(input->dim(2));
+      in_w = static_cast<int32_t>(input->dim(3));
+    } else {
+      in_h = static_cast<int32_t>(input->dim(1));
+      in_w = static_cast<int32_t>(input->dim(2));
+      in_c = static_cast<int32_t>(input->dim(3));
+    }
+    MACE_CHECK(in_c == kernels[1],
+               "different number of input channels between input and kernel");
+
+    int32_t out_h = 0, out_w = 0;
+    if (!paddings.empty()) {
+      out_h = (in_h - kernels[2] + paddings[0]) / strides[0] + 1;
+      out_w = (in_w - kernels[3] + paddings[1]) / strides[1] + 1;
+    } else {
+      switch (padding_type) {
+        case SAME:
+          out_h = (in_h + strides[0] - 1) / strides[0];
+          out_w = (in_w + strides[1] - 1) / strides[1];
+          break;
+        case VALID:
+          out_h = (in_h - kernels[2] + 1) / strides[0];
+          out_w = (in_w - kernels[3] + 1) / strides[1];
+          break;
+        default:
+          MACE_NOT_IMPLEMENTED;
+          break;
+      }
+    }
+
+    if (isNCHW) {
+      output_data[0] = out_batch;
+      output_data[1] = out_channel;
+      output_data[2] = out_h;
+      output_data[3] = out_w;
+    } else {
+      output_data[0] = out_batch;
+      output_data[1] = out_h;
+      output_data[2] = out_w;
+      output_data[3] = out_channel;
+    }
+
+    SetFutureDefaultWaitFn(future);
+
+    return MACE_SUCCESS;
+  }
+
+ private:
+  MACE_OP_INPUT_TAGS(INPUT);
+  MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_INFER_CONV2D_SHAPE_H_
diff --git a/mace/ops/infer_conv2d_shape_test.cc b/mace/ops/infer_conv2d_shape_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4f2e0b769e237ab1e04eef07114480718e867b00
--- /dev/null
+++ b/mace/ops/infer_conv2d_shape_test.cc
@@ -0,0 +1,67 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+#include "mace/kernels/conv_pool_2d_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class InferConv2dShapeOpTest : public OpsTestBase {};
+
+namespace {
+
+void TestInferConv2dShapeOp(const std::vector<index_t> &input_shape,
+                            const int stride,
+                            const std::vector<index_t> &output_shape) {
+  OpsTestNet net;
+  net.AddRandomInput<DeviceType::CPU, float>("Input", input_shape);
+
+  const int in_ch = static_cast<int>(input_shape[3]);
+  const int out_ch = static_cast<int>(output_shape[3]);
+
+  OpDefBuilder("InferConv2dShape", "InferConv2dShapeOpTest")
+      .Input("Input")
+      .Output("Output")
+      .AddIntArg("data_format", 0)
+      .AddIntsArg("strides", {stride, stride})
+      .AddIntsArg("kernels", {out_ch, in_ch, 3, 3})
+      .AddIntArg("padding", Padding::SAME)
+      .OutputType({DataTypeToEnum<int32_t>::v()})
+      .Finalize(net.NewOperatorDef());
+
+  net.RunOp();
+
+  std::vector<int32_t> expected_output_shape(output_shape.begin(),
+                                             output_shape.end());
+  net.AddInputFromArray<DeviceType::CPU, int32_t>(
+      "ExpectedOutput",
+      {static_cast<index_t>(output_shape.size())},
+      expected_output_shape);
+
+  ExpectTensorNear<int32_t>(*net.GetOutput("ExpectedOutput"),
+                            *net.GetOutput("Output"));
+}
+
+}  // namespace
+
+TEST_F(InferConv2dShapeOpTest, TestInferConv2dShape) {
+  TestInferConv2dShapeOp({3, 640, 480, 16}, 1, {3, 640, 480, 3});
+  TestInferConv2dShapeOp({3, 640, 480, 16}, 2, {3, 320, 240, 3});
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc
index c318eb4417165ecfe6a2aa49339af3fa98093964..a798015380b4d933320c85b2e48b801ea6c793be 100644
--- a/mace/ops/ops_register.cc
+++ b/mace/ops/ops_register.cc
@@ -39,6 +39,7 @@ extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry);
 extern void Register_FullyConnected(OperatorRegistryBase *op_registry);
 extern void Register_Gather(OperatorRegistryBase *op_registry);
 extern void Register_Identity(OperatorRegistryBase *op_registry);
+extern void Register_InferConv2dShape(OperatorRegistryBase *op_registry);
 extern void Register_LocalResponseNorm(OperatorRegistryBase *op_registry);
 extern void Register_MatMul(OperatorRegistryBase *op_registry);
 extern void Register_Pad(OperatorRegistryBase *op_registry);
@@ -91,6 +92,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
   ops::Register_FullyConnected(this);
   ops::Register_Gather(this);
   ops::Register_Identity(this);
+  ops::Register_InferConv2dShape(this);
   ops::Register_LocalResponseNorm(this);
   ops::Register_MatMul(this);
   ops::Register_Pad(this);
diff --git a/mace/ops/winograd_convolution_benchmark.cc b/mace/ops/winograd_convolution_benchmark.cc
index 37f7d960ec8e51b5eab93069a15562ef58d99dc0..c616a28072adc2634bea33628be3a45c1ac5779a 100644
--- a/mace/ops/winograd_convolution_benchmark.cc
+++ b/mace/ops/winograd_convolution_benchmark.cc
@@ -43,7 +43,21 @@ void BMWinogradConvolution(
   // transform filter
   BufferToImage<D, T>(&net, "Filter", "WinoFilter",
                       kernels::BufferType::WINOGRAD_FILTER, block_size);
-  // transform input
+
+  // Infer convolution output shape
+  OpDefBuilder("InferConv2dShape", "InferConv2dShapeTest")
+      .Input("InputImage")
+      .Output("ShapeOutput")
+      .AddIntArg("data_format", 0)
+      .AddIntsArg("strides", {1, 1})
+      .AddIntsArg("kernels", {static_cast<int>(out_channels),
+                              static_cast<int>(in_channels),
+                              3, 3})
+      .AddIntArg("padding", Padding::SAME)
+      .OutputType({DataTypeToEnum<int32_t>::v()})
+      .Finalize(net.NewOperatorDef());
+
+  // Transform input
   OpDefBuilder("WinogradTransform", "WinogradTransformTest")
       .Input("InputImage")
       .Output("WinoInput")
@@ -63,6 +77,7 @@ void BMWinogradConvolution(
   // Inverse transform
   OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
       .Input("WinoGemm")
+      .Input("ShapeOutput")
      .Input("BiasImage")
      .AddIntArg("batch", batch)
      .AddIntArg("height", height)
diff --git a/mace/ops/winograd_convolution_test.cc b/mace/ops/winograd_convolution_test.cc
index c2bd6b1210ab58595300b02f00644fe0b324e3f9..2406a3614a3acb49788c2bc2ac72338e068b0a1a 100644
--- a/mace/ops/winograd_convolution_test.cc
+++ b/mace/ops/winograd_convolution_test.cc
@@ -84,6 +84,19 @@ void WinogradConvolution(const index_t batch,
   // Run on opencl
   net.RunOp(D);
 
+  OpDefBuilder("InferConv2dShape", "InferConv2dShapeTest")
+      .Input("InputImage")
+      .Output("ShapeOutput")
+      .AddIntArg("data_format", 0)
+      .AddIntsArg("strides", {1, 1})
+      .AddIntsArg("kernels", {static_cast<int>(out_channels),
+                              static_cast<int>(in_channels),
+                              3, 3})
+      .AddIntArg("padding", padding)
+      .OutputType({DataTypeToEnum<int32_t>::v()})
+      .Finalize(net.NewOperatorDef());
+  net.RunOp(D);
+
   // MatMul
   OpDefBuilder("MatMul", "MatMulTest")
       .Input("WinoFilter")
@@ -97,10 +110,8 @@ void WinogradConvolution(const index_t batch,
   // Inverse transform
   OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
       .Input("WinoGemm")
+      .Input("ShapeOutput")
       .Input("BiasImage")
-      .AddIntArg("batch", batch)
-      .AddIntArg("height", output_shape[1])
-      .AddIntArg("width", output_shape[2])
       .AddIntArg("wino_block_size", block_size)
       .Output("WinoOutputImage")
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
@@ -221,6 +232,19 @@ void WinogradConvolutionWithPad(const index_t batch,
   // Run on opencl
   net.RunOp(D);
 
+  OpDefBuilder("InferConv2dShape", "InferConv2dShapeTest")
+      .Input("InputImage")
+      .Output("ShapeOutput")
+      .AddIntArg("data_format", 0)
+      .AddIntsArg("strides", {1, 1})
+      .AddIntsArg("kernels", {static_cast<int>(out_channels),
+                              static_cast<int>(in_channels),
+                              3, 3})
+      .AddIntsArg("padding_values", {padding, padding})
+      .OutputType({DataTypeToEnum<int32_t>::v()})
+      .Finalize(net.NewOperatorDef());
+  net.RunOp(D);
+
   // MatMul
   OpDefBuilder("MatMul", "MatMulTest")
       .Input("WinoFilter")
@@ -234,10 +258,8 @@ void WinogradConvolutionWithPad(const index_t batch,
   // Inverse transform
   OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
       .Input("WinoGemm")
+      .Input("ShapeOutput")
       .Input("BiasImage")
-      .AddIntArg("batch", batch)
-      .AddIntArg("height", output_shape[1])
-      .AddIntArg("width", output_shape[2])
       .AddIntArg("wino_block_size", block_size)
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Output("WinoOutputImage")
diff --git a/mace/ops/winograd_inverse_transform.h b/mace/ops/winograd_inverse_transform.h
index 2dfa2f50cdf145e1ebf6bbfb17df8e170fb76902..0349de8ace51322cdc715c9bc81ee3c4ec21b2bb 100644
--- a/mace/ops/winograd_inverse_transform.h
+++ b/mace/ops/winograd_inverse_transform.h
@@ -17,6 +17,7 @@
 
 #include <memory>
 #include <string>
+#include <vector>
 
 #include "mace/core/operator.h"
 #include "mace/kernels/activation.h"
@@ -30,27 +31,22 @@ class WinogradInverseTransformOp : public Operator<D, T> {
  public:
   WinogradInverseTransformOp(const OperatorDef &op_def, Workspace *ws)
       : Operator<D, T>(op_def, ws),
-        functor_(OperatorBase::GetOptionalArg<int>("batch", 1),
-                 OperatorBase::GetOptionalArg<int>("height", 0),
-                 OperatorBase::GetOptionalArg<int>("width", 0),
-                 kernels::StringToActivationType(
+        functor_(kernels::StringToActivationType(
                      OperatorBase::GetOptionalArg<std::string>("activation",
                                                                "NOOP")),
                  OperatorBase::GetOptionalArg<float>("max_limit", 0.0f),
                  OperatorBase::GetOptionalArg<int>("wino_block_size", 2)) {}
 
   MaceStatus Run(StatsFuture *future) override {
-    const Tensor *input_tensor = this->Input(INPUT);
-    const Tensor *bias = this->InputSize() == 2 ? this->Input(BIAS) : nullptr;
+    const std::vector<const Tensor *> &inputs = this->Inputs();
     Tensor *output_tensor = this->Output(OUTPUT);
-    return functor_(input_tensor, bias, output_tensor, future);
+    return functor_(inputs, output_tensor, future);
   }
 
  private:
   kernels::WinogradInverseTransformFunctor<D, T> functor_;
 
 protected:
-  MACE_OP_INPUT_TAGS(INPUT, BIAS);
   MACE_OP_OUTPUT_TAGS(OUTPUT);
 };
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 99fac06f63600bb58d350f0be857e03ef83932f6..7f873dda2b7ac68a5db162fb5e12073ce54638a2 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -92,6 +92,7 @@ MaceSupportedOps = [
     'FullyConnected',
     'Gather',
     'Identity',
+    'InferConv2dShape',
     'LocalResponseNorm',
     'MatMul',
     'Pad',
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index dbb9d605842574b34c787ddf63679fecd4661c1f..dfa92bf3eeca5d63127a0456dfa69b0aa9dd2905 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -560,11 +560,47 @@ class Transformer(base_converter.ConverterInterface):
             arg.name = MaceKeyword.mace_winograd_filter_transformed
             arg.i = 1
 
+        shape_op = net.op.add()
+        shape_op.name = op.name + '_infer_shape'
+        shape_op.type = MaceOp.InferConv2dShape.name
+        shape_op.input.extend([op.input[0]])
+        shape_op.output.extend([shape_op.name])
+        shape_output_shape = shape_op.output_shape.add()
+        shape_output_shape.dims.extend([4])
+
+        kernels_arg = shape_op.arg.add()
+        kernels_arg.name = MaceKeyword.mace_kernel_str
+        kernels_arg.ints.extend([out_channels,
+                                 in_channels,
+                                 filter_height,
+                                 filter_width])
+
+        if data_format is not None:
+            data_format_arg = shape_op.arg.add()
+            data_format_arg.name = MaceKeyword.mace_data_format_str
+            data_format_arg.i = data_format.value
+
+        if ConverterUtil.get_arg(op,
+                                 MaceKeyword.mace_padding_str) \
+                is not None:
+            padding_arg = shape_op.arg.add()
+            padding_arg.name = MaceKeyword.mace_padding_str
+            padding_arg.i = ConverterUtil.get_arg(
+                op, MaceKeyword.mace_padding_str).i
+        elif ConverterUtil.get_arg(
+                op, MaceKeyword.mace_padding_values_str) \
+                is not None:
+            padding_arg = shape_op.arg.add()
+            padding_arg.name = MaceKeyword.mace_padding_values_str
+            padding_arg.ints.extend(ConverterUtil.get_arg(
+                op, MaceKeyword.mace_padding_values_str).ints)
+
         # Inverse transform
         iwt_op = net.op.add()
         iwt_op.name = op.name + '_inverse_transform'
         iwt_op.type = MaceOp.WinogradInverseTransform.name
         iwt_op.input.extend([matmul_op.output[0]])
+        iwt_op.input.extend([shape_op.output[0]])
         # biasadd
         if len(op.input) >= 3:
             iwt_op.input.extend([op.input[2]])
@@ -572,15 +608,6 @@ class Transformer(base_converter.ConverterInterface):
         iwt_output_shape = iwt_op.output_shape.add()
         iwt_output_shape.dims.extend(op.output_shape[0].dims)
 
-        batch_arg = iwt_op.arg.add()
-        batch_arg.name = 'batch'
-        batch_arg.i = batch
-        height_arg = iwt_op.arg.add()
-        height_arg.name = 'height'
-        height_arg.i = out_height
-        width_arg = iwt_op.arg.add()
-        width_arg.name = 'width'
-        width_arg.i = out_width
         blk_size_arg = iwt_op.arg.add()
         blk_size_arg.name = MaceKeyword.mace_wino_block_size
         blk_size_arg.i = block_size
@@ -1146,7 +1173,7 @@ class Transformer(base_converter.ConverterInterface):
                     MaceKeyword.mace_winograd_filter_transformed) is not None:  # noqa
                 self.buffer_to_image(op, 0, OpenCLBufferType.WINOGRAD_FILTER)
             elif op.type == MaceOp.WinogradInverseTransform.name \
-                    and len(op.input) >= 2:
+                    and len(op.input) >= 3:
                 self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT)
             elif op.type == MaceOp.FullyConnected.name:
                 self.buffer_to_image(op, 1, OpenCLBufferType.WEIGHT_WIDTH)
diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py
index c0f1ddd022671e8b4db291301abc1b08aa4fe255..d7ba945080200d55822024d04f4fe8237a3ae56a 100644
--- a/mace/python/tools/memory_optimizer.py
+++ b/mace/python/tools/memory_optimizer.py
@@ -228,7 +228,11 @@ class GPUMemoryOptimizer(MemoryOptimizer):
                 mace_pb2.GPU_IMAGE,
                 calculate_image_shape(OpenCLBufferType.IN_OUT_HEIGHT,
                                       buffer_shape))
-        elif op_type in ['Shape', 'StridedSlice', 'Stack', 'ScalarMath']:
+        elif op_type in ['Shape',
+                         'InferConv2dShape',
+                         'StridedSlice',
+                         'Stack',
+                         'ScalarMath']:
             if len(output_shape) == 1:
                 mem_block = MemoryBlock(mace_pb2.CPU_BUFFER,
                                         [output_shape[0], 1])