提交 b7870140 编写于 作者: 刘琦

Merge branch 'fix-transpose_before_shape_on_gpu' into 'master'

Remove the half-precision Transpose kernel and insert a transpose-to-buffer step before shape() for the GPU runtime.

See merge request !1123
...@@ -50,15 +50,6 @@ std::string TransformedName(const std::string &input_name, ...@@ -50,15 +50,6 @@ std::string TransformedName(const std::string &input_name,
return ss.str(); return ss.str();
} }
#ifdef MACE_ENABLE_OPENCL
// Reports whether an op of the given type requires a memory-type transform.
// "Shape" and "InferConv2dShape" only inspect tensor metadata (no tensor
// memory access), so they are exempt and return false.
bool TransformRequiredOp(const std::string &op_type) {
  static const std::unordered_set<std::string> kNoTransformOp = {
      "Shape", "InferConv2dShape"};
  return kNoTransformOp.find(op_type) == kNoTransformOp.end();
}
#endif // MACE_ENABLE_OPENCL
void BuildTransposeOpDef( void BuildTransposeOpDef(
const std::string &input_name, const std::string &input_name,
const std::string &output_name, const std::string &output_name,
...@@ -514,76 +505,73 @@ MaceStatus NetDefAdapter::AdaptMemoryType( ...@@ -514,76 +505,73 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
// (only support one kind of memory type for multiple outputs) // (only support one kind of memory type for multiple outputs)
op_registry_->GetInOutMemoryTypes(op_def->type(), context); op_registry_->GetInOutMemoryTypes(op_def->type(), context);
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
// if op is memory-unused op, no transformation int input_size = op_def->input_size();
if (TransformRequiredOp(op_def->type())) { for (int i = 0; i < input_size; ++i) {
int input_size = op_def->input_size(); if (output_map->count(op_def->input(i)) == 0) {
for (int i = 0; i < input_size; ++i) { MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr
if (output_map->count(op_def->input(i)) == 0) { && ws_->GetTensor(op_def->input(i))->is_weight(),
MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr "Tensor ", op_def->input(i), " of ",
&& ws_->GetTensor(op_def->input(i))->is_weight(), op_def->name(), " not allocated");
"Tensor ", op_def->input(i), " of ", continue;
op_def->name(), " not allocated"); }
continue; auto &input_info = output_map->at(op_def->input(i));
} // check whether to do transform
auto &input_info = output_map->at(op_def->input(i)); MemoryType src_mem_type = input_info.mem_type;
// check whether to do transform MemoryType dst_mem_type = context->GetInputMemType(i);
MemoryType src_mem_type = input_info.mem_type; auto wanted_input_dtype = context->GetInputDataType(i);
MemoryType dst_mem_type = context->GetInputMemType(i); if (src_mem_type != dst_mem_type ||
auto wanted_input_dtype = context->GetInputDataType(i); (input_info.dtype != wanted_input_dtype &&
if (src_mem_type != dst_mem_type || (src_mem_type != MemoryType::CPU_BUFFER
(input_info.dtype != wanted_input_dtype && || dst_mem_type != MemoryType::CPU_BUFFER))) {
(src_mem_type != MemoryType::CPU_BUFFER auto transformed_name = TransformedName(op_def->input(i),
|| dst_mem_type != MemoryType::CPU_BUFFER))) { "mem_type",
auto transformed_name = TransformedName(op_def->input(i), dst_mem_type);
"mem_type", // check whether the tensor has been transformed
dst_mem_type); if (transformed_set->count(transformed_name) == 0) {
// check whether the tensor has been transformed VLOG(1) << "Add Transform operation " << op_def->name()
if (transformed_set->count(transformed_name) == 0) { << " to transform tensor "
VLOG(1) << "Add Transform operation " << op_def->name() << op_def->input(i) << "', from memory type "
<< " to transform tensor " << input_info.mem_type << " to "
<< op_def->input(i) << "', from memory type " << dst_mem_type;
<< input_info.mem_type << " to " OperatorDef *transformed_op_def = target_net_def->add_op();
<< dst_mem_type; OpenCLUtil::BuildTransformOpDef(
OperatorDef *transformed_op_def = target_net_def->add_op(); op_def->input(i),
OpenCLUtil::BuildTransformOpDef( input_info.shape,
op_def->input(i), transformed_name,
input_info.shape, wanted_input_dtype,
transformed_name, context->GetInputOpenCLBufferType(i),
wanted_input_dtype, dst_mem_type,
context->GetInputOpenCLBufferType(i), input_info.data_format,
dst_mem_type, transformed_op_def);
input_info.data_format, // set data format arg
transformed_op_def); SetProtoArg<int>(transformed_op_def,
// set data format arg "data_format",
SetProtoArg<int>(transformed_op_def, static_cast<int>(input_info.data_format));
"data_format", // set output memory type argument
static_cast<int>(input_info.data_format)); SetProtoArg<int>(transformed_op_def,
// set output memory type argument OutputMemoryTypeTagName(),
SetProtoArg<int>(transformed_op_def, dst_mem_type);
OutputMemoryTypeTagName(),
dst_mem_type);
// update tensor consumer information // update tensor consumer information
output_map->at(op_def->input(i)).consumer_op_indices.push_back( output_map->at(op_def->input(i)).consumer_op_indices.push_back(
target_net_def->op_size() - 1); target_net_def->op_size() - 1);
// update output information map // update output information map
output_map->emplace( output_map->emplace(
transformed_name, transformed_name,
InternalOutputInfo( InternalOutputInfo(
dst_mem_type, dst_mem_type,
context->GetInputDataType(i), context->GetInputDataType(i),
input_info.data_format, input_info.data_format,
input_info.shape, input_info.shape,
target_net_def->op_size() - 1)); target_net_def->op_size() - 1));
// update tensor shape map // update tensor shape map
tensor_shape_map->emplace(transformed_name, input_info.shape); tensor_shape_map->emplace(transformed_name, input_info.shape);
// record transformed tensors // record transformed tensors
transformed_set->insert(transformed_name); transformed_set->insert(transformed_name);
}
// update original op_def's input
op_def->set_input(i, transformed_name);
} }
// update original op_def's input
op_def->set_input(i, transformed_name);
} }
} }
#else #else
......
...@@ -27,7 +27,10 @@ namespace mace { ...@@ -27,7 +27,10 @@ namespace mace {
namespace ops { namespace ops {
template<DeviceType D, typename T> template<DeviceType D, typename T>
class TransposeOp : public Operation { class TransposeOp;
template<DeviceType D>
class TransposeOp<D, float> : public Operation {
public: public:
explicit TransposeOp(OpConstructContext *context) explicit TransposeOp(OpConstructContext *context)
: Operation(context), : Operation(context),
...@@ -49,8 +52,8 @@ class TransposeOp : public Operation { ...@@ -49,8 +52,8 @@ class TransposeOp : public Operation {
Tensor::MappingGuard input_guard(input); Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>(); const float *input_data = input->data<float>();
T *output_data = output->mutable_data<T>(); float *output_data = output->mutable_data<float>();
return Transpose(&context->device()->cpu_runtime()->thread_pool(), return Transpose(&context->device()->cpu_runtime()->thread_pool(),
input_data, input->shape(), dims_, output_data); input_data, input->shape(), dims_, output_data);
...@@ -63,8 +66,6 @@ class TransposeOp : public Operation { ...@@ -63,8 +66,6 @@ class TransposeOp : public Operation {
void RegisterTranspose(OpRegistryBase *op_registry) { void RegisterTranspose(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Transpose", TransposeOp, MACE_REGISTER_OP(op_registry, "Transpose", TransposeOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Transpose", TransposeOp,
DeviceType::CPU, half);
} }
} // namespace ops } // namespace ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册