提交 1622cb99 编写于 作者: H hjchen2

Fix alpha tensor key

上级 a8c077df
...@@ -72,7 +72,10 @@ class LeakyReluOpConverter : public OpConverter { ...@@ -72,7 +72,10 @@ class LeakyReluOpConverter : public OpConverter {
nvinfer1::ElementWiseOperation::kSUM); nvinfer1::ElementWiseOperation::kSUM);
PADDLE_ENFORCE(nullptr != output_layer); PADDLE_ENFORCE(nullptr != output_layer);
// keep alpha tensor to avoid release it's memory // keep alpha tensor to avoid release it's memory
engine_->weight_map[op_desc.Input("alpha")[0]] = std::move(alpha_tensor); std::string alpha_name = op_desc.Output("Out")[0] + "_alpha";
PADDLE_ENFORCE(engine_->weight_map.find(alpha_name) ==
engine_->weight_map.end());
engine_->weight_map[alpha_name] = std::move(alpha_tensor);
std::string layer_name = "leaky_relu (Output: "; std::string layer_name = "leaky_relu (Output: ";
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
......
...@@ -20,8 +20,8 @@ namespace paddle { ...@@ -20,8 +20,8 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
TEST(leaky_relu_op, test_channel_wise) { TEST(leaky_relu_op, test_leaky_relu) {
std::unordered_set<std::string> parameters({"leaky_relu_alpha"}); std::unordered_set<std::string> parameters;
framework::Scope scope; framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000); TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("leaky_relu_input", nvinfer1::DimsCHW(3, 2, 2)); validator.DeclInputVar("leaky_relu_input", nvinfer1::DimsCHW(3, 2, 2));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册