提交 dc67c979 编写于 作者: T TensorFlower Gardener

Merge pull request #24275 from trevor-m:tmorris_tftrt_validators_conv

PiperOrigin-RevId: 225253270
......@@ -1533,6 +1533,24 @@ enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV };
tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
if (inputs.at(0).is_weights()) {
return tensorflow::errors::Unimplemented(
node_def.op(), " is only implemented for tensors, not weights, at ",
node_def.name());
}
if (inputs.at(1).is_tensor()) {
return tensorflow::errors::Unimplemented("Kernel for ", node_def.op(),
" must be constant weights, at ",
node_def.name());
}
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
VLOG(2) << "weight shape: " << weights_rsck.DebugString();
if (weights_rsck.shape_.nbDims != 4) {
return tensorflow::errors::Internal(
"Conv2D expects kernel of dimension 4, at: " + node_def.name());
}
if (params->validation_only) return tensorflow::Status::OK();
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
TFAttrs attrs(node_def);
......@@ -1554,12 +1572,6 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution
VLOG(2) << "groups count: " << num_groups;
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
VLOG(2) << "weight shape: " << weights_rsck.DebugString();
if (weights_rsck.shape_.nbDims != 4) {
return tensorflow::errors::Internal(
"Conv2D expects kernel of dimension 4, at: " + node_def.name());
}
if (params->converter->precision_mode() == FP16MODE) {
weights_rsck =
ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights());
......@@ -1646,7 +1658,7 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params,
case ConvolutionType::DEPTHWISE_CONV:
return ConvertConv2DHelper(params, 0);
}
return tensorflow::errors::Unimplemented("unsupported convolution type at, " +
return tensorflow::errors::Unimplemented("Unsupported convolution type, at ",
params->node_def.name());
}
......@@ -2027,9 +2039,29 @@ tensorflow::Status ConvertConv2DDepthwise(OpConverterParams* params) {
tensorflow::Status ConvertPool(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
if (inputs.at(0).is_weights()) {
return tensorflow::errors::Unimplemented(
node_def.op(), " is only implemented for tensors, not weights, at ",
node_def.name());
}
nvinfer1::PoolingType type;
if (node_def.op() == "MaxPool") {
type = nvinfer1::PoolingType::kMAX;
} else if (node_def.op() == "AvgPool") {
type = nvinfer1::PoolingType::kAVERAGE;
} else {
return tensorflow::errors::Unimplemented(
"Unsupported pooling type: ", node_def.op(), ", at ", node_def.name());
}
TFAttrs attrs(node_def);
const string padding_type = attrs.get<string>("padding");
if ((padding_type != "SAME") && (padding_type != "VALID")) {
return tensorflow::errors::Unimplemented(
"Unsupported padding type: ", padding_type, ", at ", node_def.name());
}
if (params->validation_only) return Status::OK();
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
int h_index = 2;
int w_index = 3;
const auto data_format = attrs.get<string>("data_format");
......@@ -2040,16 +2072,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) {
const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 1, 2}, &tensor));
}
nvinfer1::PoolingType type;
if (node_def.op() == "MaxPool") {
type = nvinfer1::PoolingType::kMAX;
} else if (node_def.op() == "AvgPool") {
type = nvinfer1::PoolingType::kAVERAGE;
} else {
return tensorflow::errors::Unimplemented("Unsupported pool type: ",
node_def.op());
}
const auto tf_stride = attrs.get<std::vector<int>>("strides");
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
......@@ -2058,7 +2080,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) {
auto tensor_dim = tensor->getDimensions();
std::vector<std::pair<int, int>> padding;
const string padding_type = attrs.get<string>("padding");
if (padding_type == "SAME") {
// This is NCHW tensor with no batch dimension.
// 1 -> h
......@@ -2068,9 +2089,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) {
{static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
} else if (padding_type == "VALID") {
padding = {{0, 0}, {0, 0}};
} else {
return tensorflow::errors::Unimplemented("Unsupported padding type: ",
padding_type);
}
if (padding[0].first != padding[0].second ||
......@@ -2837,6 +2855,7 @@ tensorflow::Status ConvertPad(OpConverterParams* params) {
return tensorflow::errors::Unimplemented(
"Padding layer does not support padding on dimension 1 and 3 yet");
}
if (params->validation_only) return Status::OK();
bool legit_pad = true;
nvinfer1::DimsHW pre_padding(0, 0);
......@@ -2940,6 +2959,7 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) {
inputs_vec.push_back(tensor_i);
}
if (params->validation_only) return tensorflow::Status::OK();
// nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
nvinfer1::IConcatenationLayer* layer =
......@@ -2961,12 +2981,35 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) {
auto data_format = attrs.get<string>("data_format");
if (data_format != "NCHW") {
return tensorflow::errors::Unimplemented(
"only data_format=NCHW is supported, at " + node_def.name());
node_def.op(), " only supports data_format=NCHW, at ", node_def.name());
}
bool is_training = attrs.get<bool>("is_training");
if (is_training) {
// Trying to use batchnorm in training mode is a very common problem.
// Because the error message will only be printed in VLOG(1) by the
// segmenter, we issue a special warning so that users will actually see it.
LOG(WARNING) << node_def.op() << " only supports is_training=false. If you "
<< "are using Keras, please call "
<< "keras.backend.set_learning_phase(0) before constructing "
<< "your model. At " << node_def.name();
return tensorflow::errors::Unimplemented(
"only is_training=false is supported, at " + node_def.name());
node_def.op(), " only supports is_training=false, at ",
node_def.name());
}
if (inputs.at(0).is_weights()) {
return tensorflow::errors::Unimplemented(
node_def.op(),
" is only implemented for tensor inputs, not weights, at ",
node_def.name());
}
for (int i = 1; i < 5; i++) {
if (inputs.at(i).is_tensor()) {
return tensorflow::errors::Unimplemented(
node_def.op(),
" must have constant inputs for scale, offset, mean and variance, "
"at ",
node_def.name());
}
}
nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
......@@ -2981,7 +3024,7 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) {
for (int i = 1; i < 5; i++) {
if (inputs.at(i).weights().type_ != parameter_type) {
return tensorflow::errors::Unimplemented(
"Inconsistent parameter type for batchnormis not supported, at: " +
"Inconsistent parameter type for batchnorm is not supported, at: " +
node_def.name());
}
}
......@@ -3001,6 +3044,8 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) {
"Inconsistent batchnorm parameter count, at: " + node_def.name());
}
}
if (params->validation_only) return Status::OK();
// We could technically have two weights with different shape.
// that requires two addScale op, arguably less performant
TRT_ShapedWeights combined_scale_weights =
......@@ -3286,10 +3331,14 @@ static void RegisterValidatableOpConverters(
std::unordered_map<string, OpConverter>* registration) {
// TODO(laigd): support all op types.
(*registration)["BiasAdd"] = ConvertBiasAdd;
(*registration)["ConcatV2"] = ConvertConcat;
(*registration)["Const"] = ConvertConst;
(*registration)["Conv2D"] = ConvertConv2D;
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
(*registration)["Transpose"] = ConvertTranspose;
(*registration)["Reshape"] = ConvertReshape;
(*registration)["MatMul"] = ConvertMatMul;
(*registration)["Pad"] = ConvertPad;
(*registration)["Relu6"] = ConvertRelu6;
(*registration)["Square"] = ConvertSquare;
(*registration)["ExpandDims"] = ConvertExpandDims;
......@@ -3307,6 +3356,12 @@ static void RegisterValidatableOpConverters(
for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) {
(*registration)[activation_op_type] = ConvertActivation;
}
for (auto pool_op_type : {"AvgPool", "MaxPool"}) {
(*registration)[pool_op_type] = ConvertPool;
}
for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) {
(*registration)[normalization_op_type] = ConvertFusedBatchNorm;
}
}
void TrtNodeValidator::RegisterOpValidators() {
......@@ -3315,21 +3370,10 @@ void TrtNodeValidator::RegisterOpValidators() {
void Converter::RegisterOpConverters() {
RegisterValidatableOpConverters(&op_registry_);
op_registry_["Conv2D"] = ConvertConv2D;
op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
op_registry_["MaxPool"] = ConvertPool;
op_registry_["AvgPool"] = ConvertPool;
// TODO(ben,jie): this is a temp hack.
op_registry_["Identity"] = ConvertIdentity; // Identity should be removed
op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed
op_registry_["Pad"] = ConvertPad;
op_registry_["ConcatV2"] = ConvertConcat;
op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm;
op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm;
op_registry_["Rsqrt"] = ConvertUnary;
op_registry_["Reciprocal"] = ConvertUnary;
op_registry_["Exp"] = ConvertUnary;
......
......@@ -191,7 +191,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
batch_size=batch_size,
num_parallel_calls=8))
dataset = dataset.repeat(count=1)
iterator = data.make_one_shot_iterator(dataset)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
......@@ -205,7 +205,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
batch_size=batch_size,
num_parallel_calls=8))
dataset = dataset.repeat(count=num_epochs)
iterator = data.make_one_shot_iterator(dataset)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册