未验证 提交 eb116336 编写于 作者: P Pei Yang 提交者: GitHub

batch_norm trt converter error message, test=develop (#23620)

上级 c1187cd6
......@@ -49,7 +49,6 @@ class ActivationOpConverter : public OpConverter {
layer->setAlpha(0.);
layer->setBeta(6.);
}
g
#endif
auto output_name = op_desc.Output("Out")[0];
......
......@@ -26,13 +26,37 @@ class BatchNormOpConverter : public OpConverter {
VLOG(3) << "convert a fluid batch norm op to tensorrt batch_norm";
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1); // Bias is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1); // Mean is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1); // Scale is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(),
1); // Variance is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1,
platform::errors::InvalidArgument(
"Invalid input X's size of batch_norm TRT converter. "
"Expected 1, received %d.",
op_desc.Input("X").size()));
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1,
platform::errors::InvalidArgument(
"Invalid input Bias's size of batch_norm TRT "
"converter. Expected 1, received %d.",
op_desc.Input("Bias").size())); // Bias is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1,
platform::errors::InvalidArgument(
"Invalid input Mean's size of batch_norm TRT "
"converter. Expected 1, received %d.",
op_desc.Input("Mean").size())); // Mean is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1,
platform::errors::InvalidArgument(
"Invalid input Scale's size of batch_norm TRT "
"converter. Expected 1, received %d.",
op_desc.Input("Scale").size())); // Scale is a weight
PADDLE_ENFORCE_EQ(
op_desc.Input("Variance").size(), 1,
platform::errors::InvalidArgument(
"Invalid input Variance's size of batch_norm TRT converter. "
"Expected 1, received %d.",
op_desc.Input("Variance").size())); // Variance is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1,
platform::errors::InvalidArgument(
"Invalid output Y's size of batch_norm TRT "
"converter. Expected 1, received %d.",
op_desc.Output("Y").size()));
auto* X = engine_->GetITensor(op_desc.Input("X").front());
// Declare weights
......@@ -42,10 +66,22 @@ class BatchNormOpConverter : public OpConverter {
auto* Variance_v = scope.FindVar(op_desc.Input("Variance").front());
const float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
PADDLE_ENFORCE_NOT_NULL(Bias_v);
PADDLE_ENFORCE_NOT_NULL(Mean_v);
PADDLE_ENFORCE_NOT_NULL(Scale_v);
PADDLE_ENFORCE_NOT_NULL(Variance_v);
PADDLE_ENFORCE_NOT_NULL(
Bias_v,
platform::errors::NotFound(
"Variable of Bias of batch_norm TRT converter is not found."));
PADDLE_ENFORCE_NOT_NULL(
Mean_v,
platform::errors::NotFound(
"Variable of Mean of batch_norm TRT converter is not found."));
PADDLE_ENFORCE_NOT_NULL(
Scale_v,
platform::errors::NotFound(
"Variable of Scale of batch_norm TRT converter is not found."));
PADDLE_ENFORCE_NOT_NULL(
Variance_v,
platform::errors::NotFound(
"Variable of Variance of batch_norm TRT converter is not found."));
// get tensor
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册