// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/mlu/bridges/test_helper.h"
#include <utility>
#include "lite/core/device_info.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/mlu/subgraph_compute.h"
#include "lite/kernels/npu/bridges/registry.h"

namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {

template <lite_api::PrecisionType Dtype>
void PrepareInput(Graph* graph,
                  const std::string& input_name,
                  Tensor* input_tensor) {
  // Keep a host-side copy of the input data before the tensor's buffer is
  // rebound to MLU device memory below.
  thread_local Tensor temp_input;
  temp_input.Resize(input_tensor->dims().Vectorize());
  temp_input.CopyDataFrom(*input_tensor);
  using data_type = typename MLUTypeTraits<Dtype>::type;
  auto input_node = graph->AddNode(
      input_name,
      input_tensor->dims().Vectorize(),
      CNML_TENSOR,
      CNML_NCHW,
      MLUTypeTraits<Dtype>::cnml_type,
      CNML_NHWC,
      reinterpret_cast<void*>(
          input_tensor->template mutable_data<data_type>(TARGET(kMLU))));
  CHECK(input_node);
  // Upload the saved host data into the device buffer.
  CNRT_CHECK(cnrtMemcpy(input_tensor->template mutable_data<data_type>(),
                        temp_input.mutable_data<data_type>(),
                        sizeof(data_type) * input_tensor->dims().production(),
                        CNRT_MEM_TRANS_DIR_HOST2DEV));
}

void LaunchOp(const std::shared_ptr<lite::OpLite> op,
              const std::vector<std::string>& input_var_names,
              const std::vector<std::string>& output_var_names) {
  // Initialize the CNRT runtime, bind device 0, and create an execution queue.
  CNRT_CALL(cnrtInit(0));
  lite::SetMluDevice(0);
  cnrtQueue_t queue_;
  cnrtInvokeFuncParam_t forward_param;
  u32_t affinity = 1;
  int data_param = 1;
  forward_param.data_parallelism = &data_param;
  forward_param.affinity = &affinity;
  forward_param.end = CNRT_PARAM_END;
  CNRT_CALL(cnrtCreateQueue(&queue_));
  cnrtDev_t dev_handle;
  CNRT_CALL(cnrtGetDeviceHandle(&dev_handle, 0));
  CNRT_CALL(cnrtSetCurrentDevice(dev_handle));

  auto scope = op->scope();
  auto op_type = op->op_info()->Type();
  paddle::lite::subgraph::mlu::Graph graph;

  // Convert the op to an MLU IR graph via its registered bridge.
  const auto& bridges = subgraph::Registry::Instance();
  CHECK(bridges.Exists(op_type, TARGET(kMLU)));

  // Convert each input data var and add it into the MLU IR graph.
  for (auto& input_name : input_var_names) {
    auto input_tensor = scope->FindMutableTensor(input_name);
    auto data_type = input_tensor->precision();
    switch (data_type) {
#define PREPARE_INPUT(type__)                                          \
  case PRECISION(type__):                                              \
    PrepareInput<PRECISION(type__)>(&graph, input_name, input_tensor); \
    break;
      PREPARE_INPUT(kFP16)
      PREPARE_INPUT(kFloat)
      PREPARE_INPUT(kInt8)
      PREPARE_INPUT(kInt32)
#undef PREPARE_INPUT
      default:
        CHECK(0);
    }
  }

  op->CheckShape();
  op->InferShape();
  bridges.Select(op_type, TARGET(kMLU))(reinterpret_cast<void*>(&graph),
                                        const_cast<OpLite*>(op.get()),
                                        nullptr);

  // Bind MLU device buffers to the graph's output nodes.
  for (auto& output_name : output_var_names) {
    if (graph.HasNode(output_name)) {
      graph.AddOutput(graph.GetNode(output_name));
    }
    auto output_tensor = scope->FindMutableTensor(output_name);
    void* p_data =
        static_cast<void*>(output_tensor->mutable_data<float>(TARGET(kMLU)));
    auto node = graph.GetNode(output_name);
    CHECK(p_data);
    node->set_mlu_ptr(p_data);
  }
  for (auto& input_name : input_var_names) {
    graph.AddInput(graph.GetNode(input_name));
  }

  // Compile the graph for MLU270 and run it on the queue.
  graph.Compile(CNML_MLU270, 1);
  graph.Compute(forward_param,
                queue_,
                *(graph.MutableInputs()),
                *(graph.MutableOutputs()));
  CNRT_CALL(cnrtSyncQueue(queue_));

  // Copy each output back to host memory so callers can check the results.
  for (auto& output_name : output_var_names) {
    auto output_tensor = scope->FindMutableTensor(output_name);
    Tensor temp_out;
    temp_out.Resize(output_tensor->dims().Vectorize());
    CNRT_CHECK(cnrtMemcpy(temp_out.mutable_data<float>(TARGET(kHost)),
                          output_tensor->mutable_data<float>(),
                          sizeof(float) * output_tensor->dims().production(),
                          CNRT_MEM_TRANS_DIR_DEV2HOST));
    // Rebind the output tensor to host memory, then copy the data back in.
    output_tensor->mutable_data<float>(TARGET(kHost));
    output_tensor->CopyDataFrom(temp_out);
  }
}

}  // namespace mlu
}  // namespace subgraph
}  // namespace lite
}  // namespace paddle

// USE_LITE_OP(graph_op);
// USE_LITE_KERNEL(graph_op, kMLU, kFloat, kNHWC, def);
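
// A minimal usage sketch of this helper, assuming the conventions of the
// sibling bridge unit tests: build an op and its tensors in a Scope, then
// hand the op to LaunchOp, which runs it through the MLU bridge and copies
// the outputs back to the host for comparison. The op type, variable names,
// shape, and the FillTensor/CreateOp helpers below are illustrative
// assumptions, not part of this file.
//
//   Scope scope;
//   auto* x = scope.Var("x")->GetMutable<Tensor>();
//   scope.Var("out")->GetMutable<Tensor>();
//   x->Resize({1, 3, 224, 224});
//   FillTensor<float>(x);  // hypothetical helper that randomizes host data
//
//   cpp::OpDesc opdesc;
//   opdesc.SetType("relu");
//   opdesc.SetInput("X", {"x"});
//   opdesc.SetOutput("Out", {"out"});
//
//   auto op = CreateOp<operators::ReluOp>(opdesc, &scope);  // assumed helper
//   LaunchOp(op, {"x"}, {"out"});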