Unverified commit fca8595e, authored by Yuanle Liu, committed by GitHub

[Paddle-TRT] add generic plugin for lookup_table_v2(embedding) op (#53539)

* add embedding generic plugin, not enabled
Parent 2bf61284
@@ -17,13 +17,13 @@ limitations under the License. */
 #include "paddle/fluid/inference/tensorrt/plugin/generic_plugin.h"
 #include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"
 #include "paddle/phi/api/ext/op_meta_info.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/errors.h"
 
 namespace paddle {
 namespace inference {
 namespace tensorrt {
 
-/*
- * Stack converter from fluid to tensorRT.
- */
 class CustomPluginCreater : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc &op,
@@ -164,6 +164,8 @@ class GenericPluginCreater : public OpConverter {
                   const framework::Scope &scope,
                   bool test_mode) override {
     framework::OpDesc op_desc(op, nullptr);
+    VLOG(3) << "convert " << op_desc.Type() << " op to generic plugin layer";
+
     CHECK(block_);
     const framework::BlockDesc block_desc(
         nullptr, const_cast<framework::proto::BlockDesc *>(block_));
@@ -181,6 +183,14 @@ class GenericPluginCreater : public OpConverter {
       phi_kernel_signature =
           phi::DefaultKernelSignatureMap::Instance().Get(op_desc.Type());
     }
+    VLOG(3) << phi_kernel_signature;
+    PADDLE_ENFORCE_EQ(
+        phi_kernel_signature.input_names.empty() ||
+            phi_kernel_signature.output_names.empty(),
+        false,
+        platform::errors::PreconditionNotMet(
+            "The %s op's kernel signature (inputs and outputs) should not be empty.",
+            op_desc.Type()));
 
     bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
......
@@ -417,6 +417,24 @@ nvinfer1::DimsExprs GridSamplerInferMeta(
   return output;
 }
 
+nvinfer1::DimsExprs LookupTableV2InferMeta(
+    int output_index,
+    const nvinfer1::DimsExprs* inputs,
+    int nb_inputs,
+    nvinfer1::IExprBuilder& expr_builder,  // NOLINT
+    const framework::OpDesc& op_desc) {
+  const auto x_dims = inputs[0];
+  const auto weight_dims = inputs[1];
+
+  nvinfer1::DimsExprs output;
+  output.nbDims = x_dims.nbDims + 1;
+  for (int i = 0; i < x_dims.nbDims; ++i) {
+    output.d[i] = x_dims.d[i];
+  }
+  output.d[x_dims.nbDims] = weight_dims.d[1];
+  return output;
+}
+
 PD_REGISTER_DYNAMIC_INFER_META_FN(gather_nd, GatherNdInferMeta);
 PD_REGISTER_DYNAMIC_INFER_META_FN(yolo_box, YoloBoxInferMeta);
 PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta);
@@ -427,6 +445,7 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(moe, MoeInferMeta);
 PD_REGISTER_DYNAMIC_INFER_META_FN(pad3d, Pad3dInferMeta);
 PD_REGISTER_DYNAMIC_INFER_META_FN(grid_sampler, GridSamplerInferMeta);
 PD_REGISTER_DYNAMIC_INFER_META_FN(p_norm, PNormInferMeta);
+PD_REGISTER_DYNAMIC_INFER_META_FN(lookup_table_v2, LookupTableV2InferMeta);
 
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
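The LookupTableV2InferMeta added above encodes the embedding shape rule: the output shape is the ids shape with the embedding width appended, i.e. ids [d0, ..., dn] and weight [vocab_size, D] produce [d0, ..., dn, D]. A minimal standalone sketch of the same rule, using plain vectors instead of nvinfer1::DimsExprs (EmbeddingOutShape is illustrative, not part of this patch):

    #include <cstdio>
    #include <vector>

    // Embedding lookup output shape: copy every ids dimension, then append
    // the embedding width D taken from the weight shape [vocab_size, D].
    std::vector<int> EmbeddingOutShape(const std::vector<int>& ids_shape,
                                       const std::vector<int>& weight_shape) {
      std::vector<int> out = ids_shape;
      out.push_back(weight_shape[1]);
      return out;
    }

    int main() {
      // ids [batch=4, seq=16] with weight [vocab=30522, dim=768] -> 4 16 768
      for (int d : EmbeddingOutShape({4, 16}, {30522, 768})) {
        std::printf("%d ", d);
      }
      return 0;
    }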
@@ -2865,10 +2865,10 @@ struct SimpleOpTypeSetTeller : public Teller {
       "logsigmoid",
       "preln_layernorm_shift_partition",
       "lookup_table",
-      "lookup_table_v2",
       "trans_layernorm",
       "merge_layernorm",
       "skip_merge_layernorm",
+      "lookup_table_v2",
       "expand_v2",
       "expand_as_v2",
       "fuse_eleadd_transpose",
@@ -3143,6 +3143,7 @@ OpTeller::OpTeller() {
   tellers_.emplace_back(new tensorrt::GenericPluginTeller);
   tellers_.emplace_back(new tensorrt::CustomPluginTeller);
 }
 
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/phi_utils.h"
 #include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_registry.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/compat/op_utils.h"
 #include "paddle/phi/core/kernel_context.h"
 #include "paddle/phi/core/kernel_factory.h"
@@ -354,6 +355,19 @@ bool GenericPlugin::supportsFormatCombination(
     if (pos == 3)
       return in_out[0].type == in_out[pos].type &&
             in_out[0].format == in_out[pos].format;
+  } else if (op_desc_.Type() == "lookup_table_v2") {
+    if (pos == 0)
+      return (in_out[pos].type == nvinfer1::DataType::kINT32 &&
+              (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR));
+    if (pos == 1)
+      return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) ||
+              (isFp16Supported() &&
+               in_out[pos].type == nvinfer1::DataType::kHALF)) &&
+             (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
+    // output
+    if (pos == 2)
+      return in_out[1].type == in_out[pos].type &&
+             in_out[1].format == in_out[pos].format;
   } else {
     return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
             (isFp16Supported() &&
@@ -367,6 +381,9 @@ nvinfer1::DataType GenericPlugin::getOutputDataType(
     int index,
     const nvinfer1::DataType* input_types,
     int nb_inputs) const TRT_NOEXCEPT {
+  if (op_desc_.Type() == "lookup_table_v2") {
+    return input_types[1];
+  }
   return input_types[0];
 }
@@ -472,7 +489,7 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                            cudaStream_t stream) TRT_NOEXCEPT {
   platform::CUDAPlace place(platform::GetCurrentDeviceId());
 
-  // [TODO]now generic plugin do not support INT8 precision
+  // TODO(inference): the generic plugin does not support INT8 precision yet.
   auto protoType2PhiType =
       [&](int proto_type,
           nvinfer1::DataType nv_dtype) -> std::pair<phi::DataType, int> {
@@ -503,8 +520,13 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
     }
   };
 
+  nvinfer1::DataType data_type;
   // input
-  auto data_type = input_desc[0].type;
+  if (op_desc_.Type() == "lookup_table_v2") {
+    data_type = input_desc[1].type;
+  } else {
+    data_type = input_desc[0].type;
+  }
   CHECK((data_type == nvinfer1::DataType::kFLOAT) ||
         (data_type == nvinfer1::DataType::kHALF));
@@ -555,8 +577,6 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                                      output_numel * data_type_and_size.second,
                                      place));
 
-    phi::DenseTensor output_densetonsor(output_alloc, output_meta);
-
     (*dense_tensor_outputs_)[i] =
         std::move(phi::DenseTensor(output_alloc, output_meta));
......
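Taken together, the generic_plugin.cu changes special-case lookup_table_v2: the ids input (pos 0) must be INT32 with linear layout, the weight input (pos 1) carries the floating-point precision, and both the output type and the enqueue-time data_type follow input 1 instead of the usual input 0. A compilable sketch of that contract, with plain enums standing in for the TensorRT types (SupportsCombination is a hypothetical helper, not Paddle or TensorRT API):

    #include <cassert>

    enum class DType { kFLOAT, kHALF, kINT32 };

    // Mirrors the lookup_table_v2 branch of supportsFormatCombination:
    // pos 0 = ids (INT32), pos 1 = weight (FP32/FP16), pos 2 = output.
    bool SupportsCombination(int pos, const DType io[3], bool fp16_ok) {
      if (pos == 0) return io[0] == DType::kINT32;
      if (pos == 1)
        return io[1] == DType::kFLOAT || (fp16_ok && io[1] == DType::kHALF);
      return io[2] == io[1];  // output precision follows the weight, not the ids
    }

    int main() {
      const DType io[3] = {DType::kINT32, DType::kHALF, DType::kHALF};
      for (int pos = 0; pos < 3; ++pos) {
        assert(SupportsCombination(pos, io, /*fp16_ok=*/true));
      }
      return 0;
    }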
@@ -15,8 +15,6 @@
 #pragma once
 
 #include <NvInfer.h>
-#include <stdio.h>
-#include <cassert>
 #include <string>
 #include <vector>
@@ -84,45 +82,50 @@ class GenericPlugin : public DynamicPluginTensorRT {
   // Shutdown the layer. This is called when the engine is destroyed
   void terminate() TRT_NOEXCEPT override;
 
-  void destroy() TRT_NOEXCEPT{};
+  void destroy() TRT_NOEXCEPT override{};
 
-  size_t getSerializationSize() const TRT_NOEXCEPT {
-    return op_meta_data_.size() + SerializedSize(inputs_data_type_) +
-           SerializedSize(outputs_data_type_) + SerializedSize(with_fp16_);
+  size_t getSerializationSize() const TRT_NOEXCEPT override {
+    size_t sum = 0;
+    sum += SerializedSize(inputs_data_type_);
+    sum += SerializedSize(outputs_data_type_);
+    sum += SerializedSize(with_fp16_);
+    sum += op_meta_data_.size();
+    return sum;
   }
 
-  void serialize(void* buffer) const TRT_NOEXCEPT;
+  void serialize(void* buffer) const TRT_NOEXCEPT override;
 
   // The Func in IPluginV2
-  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT;
+  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
 
   nvinfer1::DimsExprs getOutputDimensions(
       int output_index,
       const nvinfer1::DimsExprs* inputs,
       int nb_inputs,
       nvinfer1::IExprBuilder& expr_builder)  // NOLINT
-      TRT_NOEXCEPT;
+      TRT_NOEXCEPT override;
 
   bool supportsFormatCombination(int pos,
                                  const nvinfer1::PluginTensorDesc* in_out,
                                  int nb_inputs,
-                                 int nb_outputs) TRT_NOEXCEPT;
+                                 int nb_outputs) TRT_NOEXCEPT override;
 
   void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                        int nb_inputs,
                        const nvinfer1::DynamicPluginTensorDesc* out,
-                       int nb_outputs) TRT_NOEXCEPT;
+                       int nb_outputs) TRT_NOEXCEPT override;
 
   int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
               const nvinfer1::PluginTensorDesc* output_desc,
               const void* const* inputs,
               void* const* outputs,
               void* workspace,
-              cudaStream_t stream) TRT_NOEXCEPT;
+              cudaStream_t stream) TRT_NOEXCEPT override;
 
   nvinfer1::DataType getOutputDataType(int index,
                                        const nvinfer1::DataType* input_types,
-                                       int nb_inputs) const TRT_NOEXCEPT;
+                                       int nb_inputs) const
+      TRT_NOEXCEPT override;
 
   bool isFp16Supported() {
     auto half_dtype = nvinfer1::DataType::kHALF;
@@ -146,7 +149,6 @@ class GenericPlugin : public DynamicPluginTensorRT {
   std::vector<phi::DenseTensor>* dense_tensor_outputs_{nullptr};
 
  private:
-  InputOutPutVarInfo in_out_info_;
   std::vector<int> inputs_data_type_;
   std::vector<int> outputs_data_type_;
 };
@@ -166,6 +168,7 @@ class GenericPluginCreator : public TensorRTPluginCreator {
     return new GenericPlugin(serial_data, serial_length);
   }
 };
 
 REGISTER_TRT_PLUGIN_V2(GenericPluginCreator);
 
 }  // namespace plugin
......
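In the header, getSerializationSize is rewritten as an explicit per-field sum, which makes the invariant easier to audit: every byte that serialize() writes must be counted exactly once. A hedged sketch of that accounting, where PluginState and this SerializedSize are stand-ins rather than Paddle's actual types:

    #include <string>
    #include <vector>

    // Stand-in for Paddle's SerializedSize: a length prefix plus payload.
    static size_t SerializedSize(const std::vector<int>& v) {
      return sizeof(size_t) + v.size() * sizeof(int);
    }
    static size_t SerializedSize(bool) { return sizeof(bool); }

    struct PluginState {
      std::string op_meta_data_;
      std::vector<int> inputs_data_type_;
      std::vector<int> outputs_data_type_;
      bool with_fp16_ = false;

      // Same shape as the new getSerializationSize above: one line per field
      // that serialize() writes, so additions to the state are hard to miss.
      size_t getSerializationSize() const {
        size_t sum = 0;
        sum += SerializedSize(inputs_data_type_);
        sum += SerializedSize(outputs_data_type_);
        sum += SerializedSize(with_fp16_);
        sum += op_meta_data_.size();
        return sum;
      }
    };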
@@ -19,7 +19,7 @@ namespace inference {
 namespace tensorrt {
 
 bool PluginArgumentMappingContext::HasInput(const std::string& name) const {
-  auto inputs = op_desc_ptr_->Inputs();
+  auto inputs = op_desc_->Inputs();
   for (auto& i : inputs) {
     if (i.first == name && !i.second.empty()) return true;
   }
@@ -27,7 +27,7 @@ bool PluginArgumentMappingContext::HasInput(const std::string& name) const {
 }
 
 bool PluginArgumentMappingContext::HasOutput(const std::string& name) const {
-  auto outputs = op_desc_ptr_->Outputs();
+  auto outputs = op_desc_->Outputs();
   for (auto& i : outputs) {
     if (i.first == name && !i.second.empty()) return true;
   }
@@ -35,47 +35,44 @@ bool PluginArgumentMappingContext::HasOutput(const std::string& name) const {
 }
 
 bool PluginArgumentMappingContext::HasAttr(const std::string& name) const {
-  return op_desc_ptr_->HasAttr(name);
+  return op_desc_->HasAttr(name);
 }
 
 paddle::any PluginArgumentMappingContext::Attr(
     const std::string& attr_name) const {
-  auto attr_type = op_desc_ptr_->GetAttrType(attr_name);
+  auto attr_type = op_desc_->GetAttrType(attr_name);
   switch (attr_type) {
     case framework::proto::AttrType::INT: {
-      return PADDLE_GET_CONST(int, op_desc_ptr_->GetAttr(attr_name));
+      return PADDLE_GET_CONST(int, op_desc_->GetAttr(attr_name));
       break;
     };
    case framework::proto::AttrType::FLOAT: {
-      return PADDLE_GET_CONST(float, op_desc_ptr_->GetAttr(attr_name));
+      return PADDLE_GET_CONST(float, op_desc_->GetAttr(attr_name));
       break;
     };
    case framework::proto::AttrType::STRING: {
-      return PADDLE_GET_CONST(std::string, op_desc_ptr_->GetAttr(attr_name));
+      return PADDLE_GET_CONST(std::string, op_desc_->GetAttr(attr_name));
       break;
     };
    case framework::proto::AttrType::INTS: {
-      return PADDLE_GET_CONST(std::vector<int>,
-                              op_desc_ptr_->GetAttr(attr_name));
+      return PADDLE_GET_CONST(std::vector<int>, op_desc_->GetAttr(attr_name));
       break;
     };
    case framework::proto::AttrType::FLOATS: {
-      return PADDLE_GET_CONST(std::vector<float>,
-                              op_desc_ptr_->GetAttr(attr_name));
+      return PADDLE_GET_CONST(std::vector<float>, op_desc_->GetAttr(attr_name));
       break;
     };
    case framework::proto::AttrType::STRINGS: {
       return PADDLE_GET_CONST(std::vector<std::string>,
-                              op_desc_ptr_->GetAttr(attr_name));
+                              op_desc_->GetAttr(attr_name));
       break;
     };
    case framework::proto::AttrType::BOOLEAN: {
-      return PADDLE_GET_CONST(bool, op_desc_ptr_->GetAttr(attr_name));
+      return PADDLE_GET_CONST(bool, op_desc_->GetAttr(attr_name));
       break;
     };
    case framework::proto::AttrType::BOOLEANS: {
-      return PADDLE_GET_CONST(std::vector<bool>,
-                              op_desc_ptr_->GetAttr(attr_name));
+      return PADDLE_GET_CONST(std::vector<bool>, op_desc_->GetAttr(attr_name));
       break;
     };
    default: {
@@ -87,54 +84,82 @@ paddle::any PluginArgumentMappingContext::Attr(
 }
 
 size_t PluginArgumentMappingContext::InputSize(const std::string& name) const {
-  return op_desc_ptr_->Inputs().at(name).size();
+  return op_desc_->Inputs().at(name).size();
 }
 
 size_t PluginArgumentMappingContext::OutputSize(const std::string& name) const {
-  return op_desc_ptr_->Outputs().at(name).size();
+  return op_desc_->Outputs().at(name).size();
 }
 
 bool PluginArgumentMappingContext::IsDenseTensorInput(
     const std::string& name) const {
-  return false;
+  return true;
 }
 
 bool PluginArgumentMappingContext::IsDenseTensorInputs(
     const std::string& name) const {
-  return false;
+  return true;
 }
 
-bool PluginArgumentMappingContext::IsSelectedRowsInput(
+bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
     const std::string& name) const {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Not supported for input vector of DenseTensor."));
   return false;
 }
 
-bool PluginArgumentMappingContext::IsSelectedRowsInputs(
+bool PluginArgumentMappingContext::IsDenseTensorOutput(
     const std::string& name) const {
-  return false;
+  return true;
 }
 
-bool PluginArgumentMappingContext::IsSparseCooTensorInput(
+bool PluginArgumentMappingContext::IsSelectedRowsInput(
    const std::string& name) const {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("Not supported for input of SelectedRows."));
   return false;
 }
 
-bool PluginArgumentMappingContext::IsSparseCooTensorOutput(
+bool PluginArgumentMappingContext::IsSelectedRowsInputs(
    const std::string& name) const {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("Not supported for inputs of SelectedRows."));
   return false;
 }
 
-bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
+bool PluginArgumentMappingContext::IsSelectedRowsOutput(
    const std::string& name) const {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("Not supported for output of SelectedRows."));
   return false;
 }
 
-bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
+bool PluginArgumentMappingContext::IsSparseCooTensorInput(
    const std::string& name) const {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Not supported for input of SparseCooTensor."));
   return false;
 }
 
-bool PluginArgumentMappingContext::IsDenseTensorOutput(
+bool PluginArgumentMappingContext::IsSparseCooTensorOutput(
    const std::string& name) const {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Not supported for output of SparseCooTensor."));
   return false;
 }
 
-bool PluginArgumentMappingContext::IsSelectedRowsOutput(
+bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
    const std::string& name) const {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Not supported for input of SparseCsrTensor."));
   return false;
 }
 
+bool PluginArgumentMappingContext::IsForInferShape() const {
+  PADDLE_THROW(phi::errors::Unimplemented("Not supported for InferShape."));
+  return false;
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
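The rewritten PluginArgumentMappingContext now assumes every plugin input and output is a plain DenseTensor and fails fast on tensor kinds the generic plugin cannot represent, instead of silently answering false. A small sketch of that fail-fast policy, with ArgMappingPolicy as an illustrative type and std::logic_error standing in for PADDLE_THROW:

    #include <stdexcept>
    #include <string>

    // TensorRT plugin I/O is always a DenseTensor; asking about any other
    // tensor kind indicates a programming error, so throw instead of
    // returning a misleading false.
    struct ArgMappingPolicy {
      bool IsDenseTensorInput(const std::string&) const { return true; }
      bool IsSelectedRowsInput(const std::string&) const {
        throw std::logic_error("Not supported for input of SelectedRows.");
      }
    };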
@@ -25,8 +25,8 @@ namespace tensorrt {
 class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
  public:
-  explicit PluginArgumentMappingContext(framework::OpDesc* op_desc_ptr)
-      : op_desc_ptr_(op_desc_ptr) {}
+  explicit PluginArgumentMappingContext(const framework::OpDesc* op_desc)
+      : op_desc_(op_desc) {}
 
   bool HasInput(const std::string& name) const override;
@@ -60,11 +60,12 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
   bool IsSelectedRowsOutput(const std::string& name) const override;
 
-  bool IsForInferShape() const override { return false; }
+  bool IsForInferShape() const override;
 
  private:
-  framework::OpDesc* op_desc_ptr_;
+  const framework::OpDesc* op_desc_;
 };
 
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -117,15 +117,10 @@ TEST(ArgMappingContexTest, BasicFunction) {
   EXPECT_EQ(context.InputSize("X"), true);
   EXPECT_EQ(context.OutputSize("Out"), true);
 
-  EXPECT_EQ(context.IsDenseTensorInput("X"), false);
-  EXPECT_EQ(context.IsDenseTensorInputs("X"), false);
-  EXPECT_EQ(context.IsSelectedRowsInput("X"), false);
-  EXPECT_EQ(context.IsDenseTensorVectorInput("X"), false);
-  EXPECT_EQ(context.IsDenseTensorOutput("Out"), false);
-  EXPECT_EQ(context.IsSelectedRowsOutput("Out"), false);
-  EXPECT_EQ(context.IsSparseCooTensorOutput("Out"), false);
-  EXPECT_EQ(context.IsForInferShape(), false);
+  EXPECT_EQ(context.IsDenseTensorInput("X"), true);
+  EXPECT_EQ(context.IsDenseTensorInputs("X"), true);
+  EXPECT_EQ(context.IsDenseTensorOutput("Out"), true);
 }
 
 }  // namespace tensorrt
......