// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/mlu/bridges/test_helper.h"
#include <utility>
#include "lite/core/device_info.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/mlu/subgraph_compute.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {

template <lite_api::PrecisionType Dtype>
void PrepareInput(Graph* graph,
                  const std::string& input_name,
J
jiaopu 已提交
30 31
                  Tensor* input_tensor,
                  cnmlDataOrder_t order) {
D
dingminghui 已提交
32 33 34 35 36 37 38 39 40 41
  thread_local Tensor temp_input;
  temp_input.Resize(input_tensor->dims().Vectorize());
  temp_input.CopyDataFrom(*input_tensor);
  using data_type = typename MLUTypeTraits<Dtype>::type;
  auto input_node = graph->AddNode(
      input_name,
      input_tensor->dims().Vectorize(),
      CNML_TENSOR,
      CNML_NCHW,
      MLUTypeTraits<Dtype>::cnml_type,
J
jiaopu 已提交
42
      order,
D
dingminghui 已提交
43 44 45 46 47 48 49 50 51
      reinterpret_cast<void*>(
          input_tensor->template mutable_data<data_type>(TARGET(kMLU))));
  CHECK(input_node);
  CNRT_CHECK(cnrtMemcpy(input_tensor->template mutable_data<data_type>(),
                        temp_input.mutable_data<data_type>(),
                        sizeof(data_type) * input_tensor->dims().production(),
                        CNRT_MEM_TRANS_DIR_HOST2DEV));
}

// Converts |op| into an MLU IR graph via the registered subgraph bridge,
// compiles and runs the graph on MLU device 0, then copies every output
// tensor back to host memory so callers can inspect the results.
//
// @param op               op to convert and execute; its scope must already
//                         hold all input/output tensors.
// @param input_var_names  names of the op's input variables in the scope.
// @param output_var_names names of the op's output variables in the scope.
// @param order            cnml data order used when registering input nodes.
void LaunchOp(const std::shared_ptr<lite::OpLite> op,
              const std::vector<std::string>& input_var_names,
              const std::vector<std::string>& output_var_names,
              cnmlDataOrder_t order) {
  // Runtime/device setup: everything below runs on MLU device 0.
  CNRT_CALL(cnrtInit(0));
  lite::SetMluDevice(0);
  cnrtQueue_t queue;  // local, so no trailing-underscore member naming
  CNRT_CALL(cnrtCreateQueue(&queue));
  cnrtDev_t dev_handle;
  CNRT_CALL(cnrtGetDeviceHandle(&dev_handle, 0));
  CNRT_CALL(cnrtSetCurrentDevice(dev_handle));
  auto scope = op->scope();
  auto op_type = op->op_info()->Type();
  paddle::lite::subgraph::mlu::Graph graph;
  // A converter (bridge) for this op type must be registered for kMLU.
  const auto& bridges = subgraph::Registry::Instance();
  CHECK(bridges.Exists(op_type, TARGET(kMLU)));

  // Convert each input var: move its data to the device and add a matching
  // input node to the MLU IR graph, dispatched on the tensor's precision.
  for (auto& input_name : input_var_names) {
    auto input_tensor = scope->FindMutableTensor(input_name);
    auto data_type = input_tensor->precision();

    switch (data_type) {
#define PREPARE_INPUT(type__)                                                 \
  case PRECISION(type__):                                                     \
    PrepareInput<PRECISION(type__)>(&graph, input_name, input_tensor, order); \
    break;
      PREPARE_INPUT(kFP16)
      PREPARE_INPUT(kFloat)
      PREPARE_INPUT(kInt8)
      PREPARE_INPUT(kInt32)
#undef PREPARE_INPUT
      default:
        CHECK(0);  // unsupported input precision
    }
  }
  op->CheckShape();
  op->InferShape();
  // Invoke the bridge: translates the op into nodes of |graph|.
  bridges.Select(op_type, TARGET(kMLU))(
      reinterpret_cast<void*>(&graph), const_cast<OpLite*>(op.get()), nullptr);

  // Allocate device buffers for the outputs and bind them to the graph's
  // output nodes.
  for (auto& output_name : output_var_names) {
    if (graph.HasNode(output_name)) {
      graph.AddOutput(graph.GetNode(output_name));
    }
    auto output_tensor = scope->FindMutableTensor(output_name);
    void* p_data =
        static_cast<void*>(output_tensor->mutable_data<float>(TARGET(kMLU)));
    auto node = graph.GetNode(output_name);
    CHECK(p_data);
    node->set_mlu_ptr(p_data);
  }
  for (auto& input_name : input_var_names) {
    graph.AddInput(graph.GetNode(input_name));
  }

  graph.Compile(CNML_MLU270, 1);
  graph.Compute(queue, *(graph.MutableInputs()), *(graph.MutableOutputs()));
  CNRT_CALL(cnrtSyncQueue(queue));

  // Copy each output back to host: stage into a host tensor, retarget the
  // scope tensor's storage to host, then copy the staged data in.
  for (auto& output_name : output_var_names) {
    auto output_tensor = scope->FindMutableTensor(output_name);
    Tensor temp_out;
    temp_out.Resize(output_tensor->dims().Vectorize());
    CNRT_CHECK(cnrtMemcpy(temp_out.mutable_data<float>(TARGET(kHost)),
                          output_tensor->mutable_data<float>(),
                          sizeof(float) * output_tensor->dims().production(),
                          CNRT_MEM_TRANS_DIR_DEV2HOST));
    output_tensor->mutable_data<float>(TARGET(kHost));
    output_tensor->CopyDataFrom(temp_out);
  }
  // Fix: release the queue created above — it was previously leaked.
  CNRT_CALL(cnrtDestroyQueue(queue));
}

}  // namespace mlu
}  // namespace subgraph
}  // namespace lite
}  // namespace paddle

// USE_LITE_OP(graph_op);
// USE_LITE_KERNEL(graph_op, kMLU, kFloat, kNHWC, def);