未验证 提交 eb116336 编写于 作者: P Pei Yang 提交者: GitHub

batch_norm trt converter error message, test=develop (#23620)

上级 c1187cd6
...@@ -49,10 +49,9 @@ class ActivationOpConverter : public OpConverter { ...@@ -49,10 +49,9 @@ class ActivationOpConverter : public OpConverter {
layer->setAlpha(0.); layer->setAlpha(0.);
layer->setBeta(6.); layer->setBeta(6.);
} }
g
#endif #endif
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode); RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode);
if (op_desc.HasAttr("out_scale")) { if (op_desc.HasAttr("out_scale")) {
......
...@@ -26,13 +26,37 @@ class BatchNormOpConverter : public OpConverter { ...@@ -26,13 +26,37 @@ class BatchNormOpConverter : public OpConverter {
VLOG(3) << "convert a fluid batch norm op to tensorrt batch_norm"; VLOG(3) << "convert a fluid batch norm op to tensorrt batch_norm";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1,
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1); // Bias is a weight platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1); // Mean is a weight "Invalid input X's size of batch_norm TRT converter. "
PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1); // Scale is a weight "Expected 1, received %d.",
PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(), op_desc.Input("X").size()));
1); // Variance is a weight PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1,
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1); platform::errors::InvalidArgument(
"Invalid input Bias's size of batch_norm TRT "
"converter. Expected 1, received %d.",
op_desc.Input("Bias").size())); // Bias is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1,
platform::errors::InvalidArgument(
"Invalid input Mean's size of batch_norm TRT "
"converter. Expected 1, received %d.",
op_desc.Input("Mean").size())); // Mean is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1,
platform::errors::InvalidArgument(
"Invalid input Scale's size of batch_norm TRT "
"converter. Expected 1, received %d.",
op_desc.Input("Scale").size())); // Scale is a weight
PADDLE_ENFORCE_EQ(
op_desc.Input("Variance").size(), 1,
platform::errors::InvalidArgument(
"Invalid input Variance's size of batch_norm TRT converter. "
"Expected 1, received %d.",
op_desc.Input("Variance").size())); // Variance is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1,
platform::errors::InvalidArgument(
"Invalid output Y's size of batch_norm TRT "
"converter. Expected 1, received %d.",
op_desc.Output("Y").size()));
auto* X = engine_->GetITensor(op_desc.Input("X").front()); auto* X = engine_->GetITensor(op_desc.Input("X").front());
// Declare weights // Declare weights
...@@ -42,10 +66,22 @@ class BatchNormOpConverter : public OpConverter { ...@@ -42,10 +66,22 @@ class BatchNormOpConverter : public OpConverter {
auto* Variance_v = scope.FindVar(op_desc.Input("Variance").front()); auto* Variance_v = scope.FindVar(op_desc.Input("Variance").front());
const float eps = boost::get<float>(op_desc.GetAttr("epsilon")); const float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
PADDLE_ENFORCE_NOT_NULL(Bias_v); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE_NOT_NULL(Mean_v); Bias_v,
PADDLE_ENFORCE_NOT_NULL(Scale_v); platform::errors::NotFound(
PADDLE_ENFORCE_NOT_NULL(Variance_v); "Variable of Bias of batch_norm TRT converter is not found."));
PADDLE_ENFORCE_NOT_NULL(
Mean_v,
platform::errors::NotFound(
"Variable of Mean of batch_norm TRT converter is not found."));
PADDLE_ENFORCE_NOT_NULL(
Scale_v,
platform::errors::NotFound(
"Variable of Scale of batch_norm TRT converter is not found."));
PADDLE_ENFORCE_NOT_NULL(
Variance_v,
platform::errors::NotFound(
"Variable of Variance of batch_norm TRT converter is not found."));
// get tensor // get tensor
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>(); auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册