提交 1622cb99 编写于 作者: H hjchen2

Fix alpha tensor key

上级 a8c077df
......@@ -72,7 +72,10 @@ class LeakyReluOpConverter : public OpConverter {
nvinfer1::ElementWiseOperation::kSUM);
PADDLE_ENFORCE(nullptr != output_layer);
// keep alpha tensor to avoid releasing its memory
engine_->weight_map[op_desc.Input("alpha")[0]] = std::move(alpha_tensor);
std::string alpha_name = op_desc.Output("Out")[0] + "_alpha";
PADDLE_ENFORCE(engine_->weight_map.find(alpha_name) ==
engine_->weight_map.end());
engine_->weight_map[alpha_name] = std::move(alpha_tensor);
std::string layer_name = "leaky_relu (Output: ";
auto output_name = op_desc.Output("Out")[0];
......
......@@ -20,8 +20,8 @@ namespace paddle {
namespace inference {
namespace tensorrt {
TEST(leaky_relu_op, test_channel_wise) {
std::unordered_set<std::string> parameters({"leaky_relu_alpha"});
TEST(leaky_relu_op, test_leaky_relu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("leaky_relu_input", nvinfer1::DimsCHW(3, 2, 2));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册