plugin_arg_mapping_context.cc 5.0 KB
Newer Older
W
weishengying 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"

#include <cstdint>

namespace paddle {
namespace inference {
namespace tensorrt {

bool PluginArgumentMappingContext::HasInput(const std::string& name) const {
22
  auto inputs = op_desc_->Inputs();
W
weishengying 已提交
23 24 25 26 27 28 29
  for (auto& i : inputs) {
    if (i.first == name && !i.second.empty()) return true;
  }
  return false;
}

bool PluginArgumentMappingContext::HasOutput(const std::string& name) const {
30
  auto outputs = op_desc_->Outputs();
W
weishengying 已提交
31 32 33 34 35 36 37
  for (auto& i : outputs) {
    if (i.first == name && !i.second.empty()) return true;
  }
  return false;
}

bool PluginArgumentMappingContext::HasAttr(const std::string& name) const {
38
  return op_desc_->HasAttr(name);
W
weishengying 已提交
39 40 41 42
}

// Returns the value of the op attribute `attr_name` wrapped in a paddle::any.
// The attribute's declared type (OpDesc::GetAttrType) selects which concrete
// C++ type is extracted from the stored attribute.
//
// Attribute kinds not listed below are not converted: an error is logged and
// an empty paddle::any is returned.
paddle::any PluginArgumentMappingContext::Attr(
    const std::string& attr_name) const {
  auto attr_type = op_desc_->GetAttrType(attr_name);
  switch (attr_type) {
    case framework::proto::AttrType::INT:
      return PADDLE_GET_CONST(int, op_desc_->GetAttr(attr_name));
    case framework::proto::AttrType::FLOAT:
      return PADDLE_GET_CONST(float, op_desc_->GetAttr(attr_name));
    case framework::proto::AttrType::STRING:
      return PADDLE_GET_CONST(std::string, op_desc_->GetAttr(attr_name));
    case framework::proto::AttrType::INTS:
      return PADDLE_GET_CONST(std::vector<int>, op_desc_->GetAttr(attr_name));
    case framework::proto::AttrType::FLOATS:
      return PADDLE_GET_CONST(std::vector<float>, op_desc_->GetAttr(attr_name));
    case framework::proto::AttrType::STRINGS:
      return PADDLE_GET_CONST(std::vector<std::string>,
                              op_desc_->GetAttr(attr_name));
    case framework::proto::AttrType::BOOLEAN:
      return PADDLE_GET_CONST(bool, op_desc_->GetAttr(attr_name));
    case framework::proto::AttrType::BOOLEANS:
      return PADDLE_GET_CONST(std::vector<bool>, op_desc_->GetAttr(attr_name));
    // 64-bit integer attributes (scalar and list) previously fell through to
    // the error branch below and came back as an empty paddle::any.
    case framework::proto::AttrType::LONG:
      return PADDLE_GET_CONST(int64_t, op_desc_->GetAttr(attr_name));
    case framework::proto::AttrType::LONGS:
      return PADDLE_GET_CONST(std::vector<int64_t>,
                              op_desc_->GetAttr(attr_name));
    default:
      LOG(ERROR) << "Can't convert op's attribute [" << attr_name
                 << "] to paddle any.";
      break;
  }
  return paddle::any();
}

size_t PluginArgumentMappingContext::InputSize(const std::string& name) const {
87
  return op_desc_->Inputs().at(name).size();
W
weishengying 已提交
88
}
89

W
weishengying 已提交
90
size_t PluginArgumentMappingContext::OutputSize(const std::string& name) const {
91
  return op_desc_->Outputs().at(name).size();
W
weishengying 已提交
92
}
93

W
weishengying 已提交
94 95
// Variable types are not inspected here: every input is reported as a
// DenseTensor, whatever `name` refers to.
bool PluginArgumentMappingContext::IsDenseTensorInput(
    const std::string& name) const {
  static_cast<void>(name);  // intentionally unused
  return true;
}
98

W
weishengying 已提交
99 100
// Variable types are not inspected here: the inputs bound to `name` are
// always reported as DenseTensors.
bool PluginArgumentMappingContext::IsDenseTensorInputs(
    const std::string& name) const {
  static_cast<void>(name);  // intentionally unused
  return true;
}
103 104

bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
W
weishengying 已提交
105
    const std::string& name) const {
106 107
  PADDLE_THROW(phi::errors::Unimplemented(
      "Not supported for input vector of DenseTensor."));
W
weishengying 已提交
108 109
  return false;
}
110 111

bool PluginArgumentMappingContext::IsDenseTensorOutput(
Y
YuanRisheng 已提交
112
    const std::string& name) const {
113
  return true;
Y
YuanRisheng 已提交
114
}
115 116

bool PluginArgumentMappingContext::IsSelectedRowsInput(
117
    const std::string& name) const {
118 119
  PADDLE_THROW(
      phi::errors::Unimplemented("Not supported for input of SelectedRows."));
120 121
  return false;
}
122

123
bool PluginArgumentMappingContext::IsSelectedRowsInputs(
124
    const std::string& name) const {
125 126
  PADDLE_THROW(
      phi::errors::Unimplemented("Not supported for inputs of SelectedRows."));
127 128 129
  return false;
}

130
bool PluginArgumentMappingContext::IsSelectedRowsOutput(
131
    const std::string& name) const {
132 133
  PADDLE_THROW(
      phi::errors::Unimplemented("Not supported for output of SelectedRows."));
134 135
  return false;
}
136 137

bool PluginArgumentMappingContext::IsSparseCooTensorInput(
W
weishengying 已提交
138
    const std::string& name) const {
139 140
  PADDLE_THROW(phi::errors::Unimplemented(
      "Not supported for input of SparseCooTensor."));
W
weishengying 已提交
141 142 143
  return false;
}

144
bool PluginArgumentMappingContext::IsSparseCooTensorOutput(
W
weishengying 已提交
145
    const std::string& name) const {
146 147
  PADDLE_THROW(phi::errors::Unimplemented(
      "Not supported for output of SparseCooTensor."));
W
weishengying 已提交
148 149
  return false;
}
150 151

bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
W
weishengying 已提交
152
    const std::string& name) const {
153 154
  PADDLE_THROW(phi::errors::Unimplemented(
      "Not supported for input of SparseCsrTensor."));
W
weishengying 已提交
155 156
  return false;
}
157 158 159 160 161 162

// This context does not answer whether it is used for InferShape. The query
// raises an Unimplemented error instead of returning a flag.
bool PluginArgumentMappingContext::IsForInferShape() const {
  PADDLE_THROW(
      phi::errors::Unimplemented("Not supported for InferShape."));
  // Not reached; presumably retained so the function visibly returns a value.
  return false;
}

W
weishengying 已提交
163 164 165
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle