diff --git a/mace/core/memory_optimizer.cc b/mace/core/memory_optimizer.cc
index 03e79f1e979aa77fcda8c5bbfb7ffa81fef486d6..6d08ca907b55b209cc9b5d05b7638bbf9d27fd18 100644
--- a/mace/core/memory_optimizer.cc
+++ b/mace/core/memory_optimizer.cc
@@ -31,11 +31,19 @@

 namespace mace {

-bool MemoryOptimizer::IsMemoryReuseOp(const std::string &op_type) {
+bool MemoryOptimizer::IsMemoryReuseOp(const std::string &op_type,  // true if op_type is in the reuse list for mem_type
+                                      const MemoryType mem_type) {  // mem_type selects which list is consulted
   static const std::unordered_set kReuseOp = {
     "Reshape", "Identity", "Squeeze", "ExpandDims"
   };
-  return kReuseOp.count(op_type) == 1;
+  static const std::unordered_set kGpuImageReuseOp = {  // kReuseOp without "Reshape"
+    "Identity", "Squeeze", "ExpandDims"
+  };
+  if (mem_type == MemoryType::GPU_IMAGE) {  // GPU image memory: "Reshape" is not reported as a reuse op
+    return kGpuImageReuseOp.count(op_type) == 1;
+  } else {  // every other memory type: same answer as before this change
+    return kReuseOp.count(op_type) == 1;
+  }
 }

 void MemoryOptimizer::UpdateTensorRef(const std::string &tensor_name) {
@@ -142,7 +150,7 @@ void MemoryOptimizer::Optimize(
       }
       MemoryBlock op_mem_block = CreateMemoryBlock(op_def, i, dt, mem_type);
       MemoryBlock best_mem_block;
-      if (IsMemoryReuseOp(op_def->type())) {
+      if (IsMemoryReuseOp(op_def->type(), mem_type)) {  // same mem_type that was passed to CreateMemoryBlock above
         if (tensor_mem_map_.count(op_def->input(0)) == 1) {
           best_mem_id = tensor_mem_map_.at(op_def->input(0)).mem_id;
         }
diff --git a/mace/core/memory_optimizer.h b/mace/core/memory_optimizer.h
index b4e635f54f8c1e74328803793a58ff20ceeefbf0..8cbfc5dc0d135af536eef4677e7a720f375ba7a1 100644
--- a/mace/core/memory_optimizer.h
+++ b/mace/core/memory_optimizer.h
@@ -90,7 +90,8 @@ class MemoryOptimizer {
   };

  public:
-  static bool IsMemoryReuseOp(const std::string &op_type);
+  static bool IsMemoryReuseOp(const std::string &op_type,
+                              const MemoryType mem_type);  // mem_type chooses the reuse list (GPU_IMAGE excludes "Reshape")
   void UpdateTensorRef(const std::string &tensor_name);
   void UpdateTensorRef(const OperatorDef *op_def);
   void Optimize(
diff --git a/mace/ops/opencl/buffer/reshape.cc b/mace/ops/opencl/buffer/reshape.cc
new file mode 100644
index 
0000000000000000000000000000000000000000..ae3c119c2368d4c57d2151a641472d508999151b
--- /dev/null
+++ b/mace/ops/opencl/buffer/reshape.cc
@@ -0,0 +1,40 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/opencl/buffer/reshape.h"
+
+#include  // NOTE(review): include target is missing in this copy of the patch
+
+#include "mace/core/operator.h"
+
+namespace mace {
+namespace ops {
+namespace opencl {
+namespace buffer {
+
+MaceStatus ReshapeKernel::Compute(OpContext *context,  // Reshape for the non-image path: reuse input's buffer, then set the new shape
+                                  const Tensor *input,
+                                  const std::vector &new_shape,  // NOTE(review): element type is missing in this copy
+                                  Tensor *output) {
+  MACE_UNUSED(context);  // the context is not used by this kernel
+  output->ReuseTensorBuffer(*input);  // presumably makes output share input's storage; confirm in Tensor
+  output->Reshape(new_shape);  // new_shape is applied as-is; this function does no validation of its own
+
+  return MaceStatus::MACE_SUCCESS;  // the only value this function returns
+}
+
+}  // namespace buffer
+}  // namespace opencl
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/opencl/buffer/reshape.h b/mace/ops/opencl/buffer/reshape.h
new file mode 100644
index 0000000000000000000000000000000000000000..f030f1e759b8fc7bf837d4b1054c062dcdac8338
--- /dev/null
+++ b/mace/ops/opencl/buffer/reshape.h
@@ -0,0 +1,44 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_OPENCL_BUFFER_RESHAPE_H_
+#define MACE_OPS_OPENCL_BUFFER_RESHAPE_H_
+
+#include "mace/ops/opencl/reshape.h"
+
+#include  // NOTE(review): include target is missing in this copy of the patch
+
+#include "mace/ops/opencl/helper.h"
+
+namespace mace {
+namespace ops {
+namespace opencl {
+namespace buffer {
+
+class ReshapeKernel : public OpenCLReshapeKernel {  // Reshape kernel chosen by the GPU op when memory type is not GPU_IMAGE
+ public:
+  ReshapeKernel() {}  // no state to set up; the class declares no data members
+
+  MaceStatus Compute(OpContext *context,
+                     const Tensor *input,
+                     const std::vector &new_shape,
+                     Tensor *output) override;  // defined in mace/ops/opencl/buffer/reshape.cc
+};
+
+}  // namespace buffer
+}  // namespace opencl
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_OPENCL_BUFFER_RESHAPE_H_
diff --git a/mace/ops/opencl/image/reshape.cc b/mace/ops/opencl/image/reshape.cc
new file mode 100644
index 0000000000000000000000000000000000000000..58ad3af94a7e457f78192992f0434e3bd329399d
--- /dev/null
+++ b/mace/ops/opencl/image/reshape.cc
@@ -0,0 +1,60 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/opencl/image/reshape.h"
+
+#include  // NOTE(review): include target is missing in this copy of the patch
+#include  // NOTE(review): include target is missing in this copy of the patch
+
+#include "mace/ops/opencl/image/buffer_to_image.h"
+#include "mace/ops/opencl/image/image_to_buffer.h"
+#include "mace/utils/memory.h"
+
+namespace mace {
+namespace ops {
+namespace opencl {
+namespace image {
+
+ReshapeKernel::ReshapeKernel(OpConstructContext *context) {  // builds the two conversion kernels and the intermediate tensor
+  i2bkernel_ = make_unique();  // NOTE(review): template argument missing in this copy; presumably the image-to-buffer kernel
+  b2ikernel_ = make_unique();  // NOTE(review): template argument missing in this copy; presumably the buffer-to-image kernel
+  inter_buffer_ =
+      make_unique(context->device()->allocator(), DT_FLOAT);  // intermediate tensor: device allocator, DT_FLOAT
+  MACE_CHECK(inter_buffer_ != nullptr);
+}
+
+MaceStatus ReshapeKernel::Compute(OpContext *context,  // Reshape for GPU image memory, routed through inter_buffer_
+                                  const Tensor *input,
+                                  const std::vector &new_shape,
+                                  Tensor *output) {
+  MaceStatus succ = i2bkernel_->Compute(context, input,  // step 1: input -> inter_buffer_
+                                        OpenCLBufferType::IN_OUT_CHANNEL,
+                                        0, inter_buffer_.get());
+  MACE_RETURN_IF_ERROR(succ);
+
+  succ = inter_buffer_->Resize(new_shape);  // step 2: give the intermediate tensor the requested shape
+  MACE_RETURN_IF_ERROR(succ);
+
+  succ = b2ikernel_->Compute(context, inter_buffer_.get(),  // step 3: inter_buffer_ -> output
+                             OpenCLBufferType::IN_OUT_CHANNEL,
+                             0, output);
+  MACE_RETURN_IF_ERROR(succ);
+
+  return MaceStatus::MACE_SUCCESS;  // reached only when all three steps succeeded
+}
+
+}  // namespace image
+}  // namespace opencl
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/opencl/image/reshape.h b/mace/ops/opencl/image/reshape.h
new file mode 100644
index 0000000000000000000000000000000000000000..4004fb5e904c4105f3ae5615e8dde37c557a62e1
--- /dev/null
+++ b/mace/ops/opencl/image/reshape.h
@@ -0,0 +1,52 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_OPENCL_IMAGE_RESHAPE_H_
+#define MACE_OPS_OPENCL_IMAGE_RESHAPE_H_
+
+#include "mace/ops/opencl/reshape.h"
+
+#include  // NOTE(review): include target is missing in this copy of the patch
+#include  // NOTE(review): include target is missing in this copy of the patch
+
+#include "mace/core/operator.h"
+#include "mace/ops/opencl/helper.h"
+#include "mace/ops/opencl/buffer_transform_kernel.h"
+
+namespace mace {
+namespace ops {
+namespace opencl {
+namespace image {
+
+class ReshapeKernel : public OpenCLReshapeKernel {  // Reshape kernel chosen by the GPU op when memory type is GPU_IMAGE
+ public:
+  explicit ReshapeKernel(OpConstructContext *context);  // context supplies the allocator for inter_buffer_
+
+  MaceStatus Compute(OpContext *context,
+                     const Tensor *input,
+                     const std::vector &new_shape,
+                     Tensor *output) override;  // defined in mace/ops/opencl/image/reshape.cc
+
+ private:
+  std::unique_ptr inter_buffer_;  // intermediate tensor; written by i2bkernel_, read by b2ikernel_
+  std::unique_ptr i2bkernel_;  // used first in Compute: input -> inter_buffer_
+  std::unique_ptr b2ikernel_;  // used last in Compute: inter_buffer_ -> output
+};
+
+}  // namespace image
+}  // namespace opencl
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_OPENCL_IMAGE_RESHAPE_H_
diff --git a/mace/ops/opencl/reshape.h b/mace/ops/opencl/reshape.h
new file mode 100644
index 0000000000000000000000000000000000000000..e389ab2ad44d57c0091700fbfa3eee3bd0ca0ff2
--- /dev/null
+++ b/mace/ops/opencl/reshape.h
@@ -0,0 +1,45 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef MACE_OPS_OPENCL_RESHAPE_H_
+#define MACE_OPS_OPENCL_RESHAPE_H_
+
+#include  // NOTE(review): include target is missing in this copy of the patch
+
+#include "mace/core/types.h"
+#include "mace/public/mace.h"
+#include "mace/utils/math.h"
+
+
+namespace mace {
+
+class OpContext;
+class Tensor;
+
+namespace ops {
+
+class OpenCLReshapeKernel {  // interface that the buffer and image ReshapeKernel classes derive from
+ public:
+  virtual MaceStatus Compute(OpContext *context,
+                             const Tensor *input,
+                             const std::vector &new_shape,
+                             Tensor *output) = 0;  // implementations produce output from input with shape new_shape
+  MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLReshapeKernel);
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_OPENCL_RESHAPE_H_
diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc
index 98ea215e7678b32170bf98d415b0c88ec23a60e6..bd9a69a55f45f8c937a7cf80fc6ab3b4200e04e4 100644
--- a/mace/ops/reshape.cc
+++ b/mace/ops/reshape.cc
@@ -17,10 +17,61 @@
 #include "mace/core/operator.h"
 #include "mace/utils/math.h"

+#ifdef MACE_ENABLE_OPENCL  // kernel headers are pulled in only for OpenCL builds
+#include "mace/ops/opencl/image/reshape.h"
+#include "mace/ops/opencl/buffer/reshape.h"
+#endif
+
 namespace mace {
 namespace ops {

-template
+namespace {
+
+MaceStatus GetOutputShape(const Tensor *input,  // resolves the requested shape against input; shared by the CPU and GPU ops
+                          const int32_t *shape_data,  // requested dims; -1 = infer, 0 = copy input's dim at that index
+                          const index_t num_dims,  // number of entries read from shape_data
+                          std::vector *out_shape) {  // receives the resolved shape (resized to num_dims)
+  MACE_CHECK(input != nullptr && shape_data != nullptr && out_shape != nullptr);
+  int unknown_idx = -1;  // index of the single -1 entry, or -1 if there is none
+  index_t product = 1;  // product of all explicitly resolved dims
+  index_t n = 0;
+
+  out_shape->resize(num_dims);
+  for (int i = 0; i < num_dims; ++i) {
+    if (shape_data[i] == -1) {  // dimension to be inferred; at most one is allowed
+      MACE_CHECK(unknown_idx == -1, "Only one input size may be -1");
+      unknown_idx = i;
+      (*out_shape)[i] = 1;  // placeholder, overwritten after the loop
+    } else {
+      MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ",
+                 shape_data[i]);
+      if (shape_data[i] == 0) {  // 0 means: take input's dimension at the same index
+        MACE_CHECK(i < input->dim_size(),
+                   "dims:0 out of input dims' range.");
+        n = input->dim(i);
+      } else {
+        n = shape_data[i];
+      }
+      (*out_shape)[i] = n;
+      product *= n;
+    }
+  }
+
+  if (unknown_idx != -1) {  // fill in the inferred dimension
+    MACE_CHECK(product != 0)
+        << "Cannot infer shape if there is zero shape size.";
+    const index_t missing = input->size() / product;
+    MACE_CHECK(missing * product == input->size())  // the division above must be exact
+        << "Input size not match reshaped tensor size";
+    (*out_shape)[unknown_idx] = missing;
+  }
+
+  return MaceStatus::MACE_SUCCESS;  // the only value this function returns
+}
+
+}  // namespace
+
+template
 class ReshapeOp : public Operation {
  public:
   explicit ReshapeOp(OpConstructContext *context)
@@ -31,46 +82,14 @@ class ReshapeOp : public Operation {
     MACE_UNUSED(context);
     const Tensor *input = this->Input(INPUT);
     const Tensor *shape = this->Input(SHAPE);
-    const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0);
     Tensor::MappingGuard shape_guard(shape);
     const int32_t *shape_data = shape->data();
+    const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0);
+    std::vector out_shape;
+    MACE_RETURN_IF_ERROR(
+        GetOutputShape(input, shape_data, num_dims, &out_shape));  // replaces the inline loop removed below

-    int unknown_idx = -1;
-    index_t product = 1;
-    std::vector out_shape(num_dims);
-    index_t n = 0;
-
-    for (int i = 0; i < num_dims; ++i) {
-      if (shape_data[i] == -1) {
-        MACE_CHECK(unknown_idx == -1, "Only one input size may be -1");
-        unknown_idx = i;
-        out_shape[i] = 1;
-      } else {
-        MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ",
-                   shape_data[i]);
-        if (shape_data[i] == 0) {
-          MACE_CHECK(i < input->dim_size(),
-                     "dims:0 out of input dims' range.");
-          n = input->dim(i);
-        } else {
-          n = shape_data[i];
-        }
-        out_shape[i] = n;
-        product *= n;
-      }
-    }
-
-    if (unknown_idx != -1) {
-      MACE_CHECK(product != 0)
-          << "Cannot infer shape if there is zero shape size.";
-      const index_t missing = input->size() / product;
-      MACE_CHECK(missing * product == input->size())
-          << "Input size not match reshaped tensor size";
-      out_shape[unknown_idx] = missing;
-    }
-    Tensor *output = this->Output(OUTPUT);
     // NHWC -> NCHW
-
     if (has_df_ && D == DeviceType::CPU && out_shape.size() == 4 &&
         shape->is_weight()) {
       std::vector dst_dims = {0, 3, 1, 2};
@@ -79,6 +98,7 @@ class ReshapeOp : public Operation {
       out_shape = trans_shape;
     }

+    Tensor *output = this->Output(OUTPUT);  // moved here from above the NHWC -> NCHW block
     output->ReuseTensorBuffer(*input);
     output->Reshape(out_shape);

@@ -93,11 +113,46 @@ class ReshapeOp : public Operation {
   MACE_OP_OUTPUT_TAGS(OUTPUT);
 };

+#ifdef MACE_ENABLE_OPENCL
+template<>
+class ReshapeOp : public Operation {  // NOTE(review): specialization arguments are missing in this copy; this is the variant registered with MACE_REGISTER_GPU_OP
+ public:
+  explicit ReshapeOp(OpConstructContext *context)
+      : Operation(context),
+        dim_(Operation::GetRepeatedArgs("dim")) {  // target shape comes from the "dim" argument, not from the SHAPE input
+    if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {  // image memory: kernel that is constructed from the context
+      kernel_ = make_unique(context);
+    } else {  // otherwise: default-constructed kernel
+      kernel_ = make_unique();
+    }
+  }
+
+  MaceStatus Run(OpContext *context) override {  // resolve the shape from dim_, then delegate to the selected kernel
+    const Tensor *input = this->Input(INPUT);
+    const int32_t *shape_data = dim_.data();  // NOTE(review): data() of an empty vector may be null, which GetOutputShape rejects
+    const index_t num_dims = dim_.size();
+    std::vector out_shape;
+    MACE_RETURN_IF_ERROR(
+        GetOutputShape(input, shape_data, num_dims, &out_shape));
+
+    Tensor *output = this->Output(OUTPUT);
+    return kernel_->Compute(context, input, out_shape, output);
+  }
+
+ private:
+  std::vector dim_;  // values of the "dim" argument, read once in the constructor
+  std::unique_ptr kernel_;  // image or buffer kernel, chosen in the constructor
+  MACE_OP_INPUT_TAGS(INPUT, SHAPE);
+  MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+#endif
+
 void RegisterReshape(OpRegistryBase *op_registry) {
   MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp,
                    DeviceType::CPU, float);
   MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp,
                    DeviceType::CPU, int32_t);
+  MACE_REGISTER_GPU_OP(op_registry, "Reshape", ReshapeOp);  // presumably registers the GPU variant defined above; confirm the macro
 }

 }  // namespace ops
diff --git a/tools/python/transform/base_converter.py b/tools/python/transform/base_converter.py
index 691ead78356811d9449a37ac0ce08bff1cc697bc..12baa73118fce24917ce469607fe10c1afeb8ab1 100644
--- a/tools/python/transform/base_converter.py
+++ b/tools/python/transform/base_converter.py
@@ -179,6 +179,7 @@ MaceTransposableDataFormatOps = [MaceOp.Activation,
                                  MaceOp.Eltwise,
                                  MaceOp.Pad,
                                  MaceOp.Reduce,
+                                 MaceOp.Reshape,
                                  MaceOp.Softmax,
                                  MaceOp.Split,
                                  MaceOp.Squeeze,
@@ -300,7 +301,7 @@ class TransformerRule(Enum):
     FOLD_SQRDIFF_MEAN = 33
     TRANSPOSE_MATMUL_WEIGHT = 34
     FOLD_EMBEDDING_LOOKUP = 35
-    TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36
+    TRANSPOSE_RESHAPE_AND_FLATTEN = 36  # same value as the removed name
FOLD_FC_RESHAPE = 37 TRANSFORM_CHANNEL_SHUFFLE = 38 UPDATE_DATA_FORMAT = 39 @@ -517,7 +518,7 @@ class ConverterOption(object): TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE, TransformerRule.TRANSFORM_BASIC_LSTMCELL, - TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN, + TransformerRule.TRANSPOSE_RESHAPE_AND_FLATTEN, TransformerRule.FOLD_RESHAPE, TransformerRule.TRANSFORM_MATMUL_TO_FC, # For StoB -> conv -> BtoS -> BN pattern diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index 69cdcbf26bddb6c6a9e3866354d210c6a08016e6..fa0fcb450a21afc72174726f294c3160e5d6bfb8 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -97,11 +97,11 @@ class Transformer(base_converter.ConverterInterface): self.add_opencl_informations, TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution, TransformerRule.UPDATE_DATA_FORMAT: self.update_data_format, + TransformerRule.TRANSPOSE_RESHAPE_AND_FLATTEN: + self.transform_reshape_and_flatten, TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format, TransformerRule.CHECK_QUANTIZE_INFO: self.check_quantize_info, - TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN: - self.transform_caffe_reshape_and_flatten, TransformerRule.TRANSFORM_CHANNEL_SHUFFLE: self.transform_channel_shuffle, TransformerRule.QUANTIZE_SPECIFIC_OPS_ONLY: @@ -1493,6 +1493,13 @@ class Transformer(base_converter.ConverterInterface): print("Transpose crop args: %s(%s)" % (op.name, op.type)) self.transpose_shape(offset_arg.ints, [0, 2, 3, 1]) + elif op.type == MaceOp.Reshape.name: + for arg in op.arg: + if arg.name == MaceKeyword.mace_dim_str and \ + len(arg.ints) == 4 and \ + src_data_format == DataFormat.NCHW and \ + has_data_format: + self.transpose_shape(arg.ints, [0, 2, 3, 1]) # transpose op output shape if src_data_format == DataFormat.NCHW and \ @@ -2048,14 +2055,16 @@ class Transformer(base_converter.ConverterInterface): arg.i = 
mace_pb2.GPU_IMAGE if self._option.cl_mem_type == "image"\ else mace_pb2.GPU_BUFFER - def transform_caffe_reshape_and_flatten(self): + def transform_reshape_and_flatten(self): net = self._model for op in net.op: - if op.type == MaceOp.Reshape.name and \ - len(op.input) == 1: + if op.type != MaceOp.Reshape.name: + continue + dim_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str) + shape_tensor = None + if len(op.input) == 1: print("Transform Caffe Reshape") dims = [] - dim_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str) axis_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str) # transform caffe reshape op if dim_arg: @@ -2080,6 +2089,13 @@ class Transformer(base_converter.ConverterInterface): mace_check(False, "Only support reshape and flatten") shape_tensor.int32_data.extend(dims) op.input.append(shape_tensor.name) + if len(op.input) == 2 and dim_arg is None: + if shape_tensor is None and op.input[1] in self._consts: + shape_tensor = self._consts[op.input[1]] + if shape_tensor is not None: + dim_arg = op.arg.add() + dim_arg.name = MaceKeyword.mace_dim_str + dim_arg.ints.extend(shape_tensor.int32_data) def fold_fc_reshape(self): net = self._model