提交 573812d0 编写于 作者: A Alexander Alekhin

Merge pull request #19373 from l-bat:lb/tf_matmul_shared

......@@ -1228,8 +1228,18 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
int kernel_blob_index = -1;
const tensorflow::TensorProto& kernelTensor = getConstBlob(layer, value_id, -1, &kernel_blob_index);
blobFromTensor(kernelTensor, layerParams.blobs[0]);
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
const String kernelTensorName = layer.input(kernel_blob_index);
std::map<String, Mat>::iterator sharedWeightsIt = sharedWeights.find(kernelTensorName);
if (sharedWeightsIt == sharedWeights.end())
{
blobFromTensor(kernelTensor, layerParams.blobs[0]);
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
sharedWeights[kernelTensorName] = layerParams.blobs[0];
}
else
{
layerParams.blobs[0] = sharedWeightsIt->second;
}
if (kernel_blob_index == 1) { // In this case output is computed by x*W formula - W should be transposed
Mat data = layerParams.blobs[0].t();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册