From aa52f08fcc87c9c36675af1e2046b28cbd661643 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8F=B6=E5=89=91=E6=AD=A6?= Date: Mon, 21 Oct 2019 16:47:39 +0800 Subject: [PATCH] support fallback from opencl to cpu in ReshapeOp --- mace/ops/reshape.cc | 46 ++++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index bd9a69a5..32597fc9 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -18,8 +18,8 @@ #include "mace/utils/math.h" #ifdef MACE_ENABLE_OPENCL -#include "mace/ops/opencl/image/reshape.h" #include "mace/ops/opencl/buffer/reshape.h" +#include "mace/ops/opencl/image/reshape.h" #endif namespace mace { @@ -46,8 +46,7 @@ MaceStatus GetOutputShape(const Tensor *input, MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ", shape_data[i]); if (shape_data[i] == 0) { - MACE_CHECK(i < input->dim_size(), - "dims:0 out of input dims' range."); + MACE_CHECK(i < input->dim_size(), "dims:0 out of input dims' range."); n = input->dim(i); } else { n = shape_data[i]; @@ -59,10 +58,10 @@ MaceStatus GetOutputShape(const Tensor *input, if (unknown_idx != -1) { MACE_CHECK(product != 0) - << "Cannot infer shape if there is zero shape size."; + << "Cannot infer shape if there is zero shape size."; const index_t missing = input->size() / product; MACE_CHECK(missing * product == input->size()) - << "Input size not match reshaped tensor size"; + << "Input size not match reshaped tensor size"; (*out_shape)[unknown_idx] = missing; } @@ -71,7 +70,7 @@ MaceStatus GetOutputShape(const Tensor *input, } // namespace -template +template class ReshapeOp : public Operation { public: explicit ReshapeOp(OpConstructContext *context) @@ -90,11 +89,11 @@ class ReshapeOp : public Operation { GetOutputShape(input, shape_data, num_dims, &out_shape)); // NHWC -> NCHW - if (has_df_ && D == DeviceType::CPU - && out_shape.size() == 4 && shape->is_weight()) { + if (has_df_ && D == DeviceType::CPU && out_shape.size() == 4 && + shape->is_weight()) { std::vector dst_dims = {0, 3, 1, 2}; - std::vector trans_shape = TransposeShape( - out_shape, dst_dims); + std::vector trans_shape = + TransposeShape(out_shape, dst_dims); out_shape = trans_shape; } @@ -114,12 +113,11 @@ class ReshapeOp : public Operation { }; #ifdef MACE_ENABLE_OPENCL -template<> +template <> class ReshapeOp : public Operation { public: explicit ReshapeOp(OpConstructContext *context) - : Operation(context), - dim_(Operation::GetRepeatedArgs("dim")) { + : Operation(context), dim_(Operation::GetRepeatedArgs("dim")) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) { kernel_ = make_unique(context); } else { @@ -148,11 +146,25 @@ class ReshapeOp : public Operation { #endif void RegisterReshape(OpRegistryBase *op_registry) { - MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, - DeviceType::CPU, float); - MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, - DeviceType::CPU, int32_t); + MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, float); + MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, int32_t); MACE_REGISTER_GPU_OP(op_registry, "Reshape", ReshapeOp); + MACE_REGISTER_OP_CONDITION( + op_registry, OpConditionBuilder("Reshape").SetDevicePlacerFunc( + [](OpConditionContext *context) -> std::set { + auto op = context->operator_def(); + if (op->output_shape_size() != op->output_size()) { + return {DeviceType::CPU, DeviceType::GPU}; + } + + auto tensor_shape_info = context->tensor_shape_info(); + const std::string &input_0 = op->input(0); + if (4 == op->output_shape(0).dims_size() && + 4 == tensor_shape_info->at(input_0).size()) { + return {DeviceType::CPU, DeviceType::GPU}; + } + return {DeviceType::CPU}; + })); } } // namespace ops -- GitLab