diff --git a/mace/core/net_def_adapter.cc b/mace/core/net_def_adapter.cc index 2e450b09c75c76a3508518b21dba90386f84631c..205dcdbe47374b92082a102eeef84dfe149794f3 100644 --- a/mace/core/net_def_adapter.cc +++ b/mace/core/net_def_adapter.cc @@ -50,15 +50,6 @@ std::string TransformedName(const std::string &input_name, return ss.str(); } -#ifdef MACE_ENABLE_OPENCL -bool TransformRequiredOp(const std::string &op_type) { - static const std::unordered_set kNoTransformOp = { - "Shape", "InferConv2dShape" - }; - return kNoTransformOp.count(op_type) == 0; -} -#endif // MACE_ENABLE_OPENCL - void BuildTransposeOpDef( const std::string &input_name, const std::string &output_name, @@ -514,76 +505,73 @@ MaceStatus NetDefAdapter::AdaptMemoryType( // (only support one kind of memory type for multiple outputs) op_registry_->GetInOutMemoryTypes(op_def->type(), context); #ifdef MACE_ENABLE_OPENCL - // if op is memory-unused op, no transformation - if (TransformRequiredOp(op_def->type())) { - int input_size = op_def->input_size(); - for (int i = 0; i < input_size; ++i) { - if (output_map->count(op_def->input(i)) == 0) { - MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr - && ws_->GetTensor(op_def->input(i))->is_weight(), - "Tensor ", op_def->input(i), " of ", - op_def->name(), " not allocated"); - continue; - } - auto &input_info = output_map->at(op_def->input(i)); - // check whether to do transform - MemoryType src_mem_type = input_info.mem_type; - MemoryType dst_mem_type = context->GetInputMemType(i); - auto wanted_input_dtype = context->GetInputDataType(i); - if (src_mem_type != dst_mem_type || - (input_info.dtype != wanted_input_dtype && - (src_mem_type != MemoryType::CPU_BUFFER - || dst_mem_type != MemoryType::CPU_BUFFER))) { - auto transformed_name = TransformedName(op_def->input(i), - "mem_type", - dst_mem_type); - // check whether the tensor has been transformed - if (transformed_set->count(transformed_name) == 0) { - VLOG(1) << "Add Transform operation " << op_def->name() - << " to transform tensor " - << op_def->input(i) << "', from memory type " - << input_info.mem_type << " to " - << dst_mem_type; - OperatorDef *transformed_op_def = target_net_def->add_op(); - OpenCLUtil::BuildTransformOpDef( - op_def->input(i), - input_info.shape, - transformed_name, - wanted_input_dtype, - context->GetInputOpenCLBufferType(i), - dst_mem_type, - input_info.data_format, - transformed_op_def); - // set data format arg - SetProtoArg(transformed_op_def, - "data_format", - static_cast(input_info.data_format)); - // set output memory type argument - SetProtoArg(transformed_op_def, - OutputMemoryTypeTagName(), - dst_mem_type); + int input_size = op_def->input_size(); + for (int i = 0; i < input_size; ++i) { + if (output_map->count(op_def->input(i)) == 0) { + MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr + && ws_->GetTensor(op_def->input(i))->is_weight(), + "Tensor ", op_def->input(i), " of ", + op_def->name(), " not allocated"); + continue; + } + auto &input_info = output_map->at(op_def->input(i)); + // check whether to do transform + MemoryType src_mem_type = input_info.mem_type; + MemoryType dst_mem_type = context->GetInputMemType(i); + auto wanted_input_dtype = context->GetInputDataType(i); + if (src_mem_type != dst_mem_type || + (input_info.dtype != wanted_input_dtype && + (src_mem_type != MemoryType::CPU_BUFFER + || dst_mem_type != MemoryType::CPU_BUFFER))) { + auto transformed_name = TransformedName(op_def->input(i), + "mem_type", + dst_mem_type); + // check whether the tensor has been transformed + if (transformed_set->count(transformed_name) == 0) { + VLOG(1) << "Add Transform operation " << op_def->name() + << " to transform tensor " + << op_def->input(i) << "', from memory type " + << input_info.mem_type << " to " + << dst_mem_type; + OperatorDef *transformed_op_def = target_net_def->add_op(); + OpenCLUtil::BuildTransformOpDef( + op_def->input(i), + input_info.shape, + transformed_name, + wanted_input_dtype, + context->GetInputOpenCLBufferType(i), + dst_mem_type, + input_info.data_format, + transformed_op_def); + // set data format arg + SetProtoArg(transformed_op_def, + "data_format", + static_cast(input_info.data_format)); + // set output memory type argument + SetProtoArg(transformed_op_def, + OutputMemoryTypeTagName(), + dst_mem_type); - // update tensor consumer information - output_map->at(op_def->input(i)).consumer_op_indices.push_back( - target_net_def->op_size() - 1); + // update tensor consumer information + output_map->at(op_def->input(i)).consumer_op_indices.push_back( + target_net_def->op_size() - 1); - // update output information map - output_map->emplace( - transformed_name, - InternalOutputInfo( - dst_mem_type, - context->GetInputDataType(i), - input_info.data_format, - input_info.shape, - target_net_def->op_size() - 1)); - // update tensor shape map - tensor_shape_map->emplace(transformed_name, input_info.shape); - // record transformed tensors - transformed_set->insert(transformed_name); - } - // update original op_def's input - op_def->set_input(i, transformed_name); + // update output information map + output_map->emplace( + transformed_name, + InternalOutputInfo( + dst_mem_type, + context->GetInputDataType(i), + input_info.data_format, + input_info.shape, + target_net_def->op_size() - 1)); + // update tensor shape map + tensor_shape_map->emplace(transformed_name, input_info.shape); + // record transformed tensors + transformed_set->insert(transformed_name); } + // update original op_def's input + op_def->set_input(i, transformed_name); } } #else diff --git a/mace/ops/transpose.cc b/mace/ops/transpose.cc index 22f60e2846ac0b34a1a843ffc8a772be532767c8..6c6993e065a9dbf1f0a0bf0e336ea32598a9989b 100644 --- a/mace/ops/transpose.cc +++ b/mace/ops/transpose.cc @@ -27,7 +27,10 @@ namespace mace { namespace ops { template -class TransposeOp : public Operation { +class TransposeOp; + +template +class TransposeOp : public Operation { public: explicit TransposeOp(OpConstructContext *context) : Operation(context), @@ -49,8 +52,8 @@ class TransposeOp : public Operation { Tensor::MappingGuard input_guard(input); Tensor::MappingGuard output_guard(output); - const T *input_data = input->data(); - T *output_data = output->mutable_data(); + const float *input_data = input->data(); + float *output_data = output->mutable_data(); return Transpose(&context->device()->cpu_runtime()->thread_pool(), input_data, input->shape(), dims_, output_data); @@ -63,8 +66,6 @@ class TransposeOp : public Operation { void RegisterTranspose(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Transpose", TransposeOp, DeviceType::CPU, float); - MACE_REGISTER_OP(op_registry, "Transpose", TransposeOp, - DeviceType::CPU, half); } } // namespace ops