未验证 提交 c7613b67 编写于 作者: K Kevin Jones 提交者: GitHub

Refactor Rfc2898DeriveBytes to support spans

上级 6cde4bed
...@@ -65,14 +65,12 @@ internal static partial class Pbkdf2Implementation ...@@ -65,14 +65,12 @@ internal static partial class Pbkdf2Implementation
Span<byte> destination) Span<byte> destination)
{ {
using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes( using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(
password.ToArray(), password,
salt.ToArray(), salt,
iterations, iterations,
hashAlgorithmName, hashAlgorithmName))
clearPassword: true))
{ {
byte[] result = deriveBytes.GetBytes(destination.Length); deriveBytes.GetBytes(destination);
result.AsSpan().CopyTo(destination);
} }
} }
} }
......
...@@ -19,14 +19,12 @@ internal static partial class Pbkdf2Implementation ...@@ -19,14 +19,12 @@ internal static partial class Pbkdf2Implementation
Debug.Assert(hashAlgorithmName.Name is not null); Debug.Assert(hashAlgorithmName.Name is not null);
using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes( using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(
password.ToArray(), password,
salt.ToArray(), salt,
iterations, iterations,
hashAlgorithmName, hashAlgorithmName))
clearPassword: true))
{ {
byte[] result = deriveBytes.GetBytes(destination.Length); deriveBytes.GetBytes(destination);
result.AsSpan().CopyTo(destination);
} }
} }
} }
......
...@@ -16,7 +16,7 @@ public partial class Rfc2898DeriveBytes : DeriveBytes ...@@ -16,7 +16,7 @@ public partial class Rfc2898DeriveBytes : DeriveBytes
{ {
private byte[] _salt; private byte[] _salt;
private uint _iterations; private uint _iterations;
private HMAC _hmac; private IncrementalHash _hmac;
private readonly int _blockSize; private readonly int _blockSize;
private byte[] _buffer; private byte[] _buffer;
...@@ -84,34 +84,36 @@ public Rfc2898DeriveBytes(string password, int saltSize, int iterations, HashAlg ...@@ -84,34 +84,36 @@ public Rfc2898DeriveBytes(string password, int saltSize, int iterations, HashAlg
HashAlgorithm = hashAlgorithm; HashAlgorithm = hashAlgorithm;
_hmac = OpenHmac(passwordBytes); _hmac = OpenHmac(passwordBytes);
CryptographicOperations.ZeroMemory(passwordBytes); CryptographicOperations.ZeroMemory(passwordBytes);
// _blockSize is in bytes, HashSize is in bits. _blockSize = _hmac.HashLengthInBytes;
_blockSize = _hmac.HashSize >> 3;
Initialize(); Initialize();
} }
internal Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm, bool clearPassword) internal Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm, bool clearPassword) :
this(
new ReadOnlySpan<byte>(password ?? throw new NullReferenceException()), // This "should" be ArgumentNullException but for compat, we throw NullReferenceException.
new ReadOnlySpan<byte>(salt ?? throw new ArgumentNullException(nameof(salt))),
iterations,
hashAlgorithm)
{ {
ArgumentNullException.ThrowIfNull(salt); if (clearPassword)
{
CryptographicOperations.ZeroMemory(password);
}
}
internal Rfc2898DeriveBytes(ReadOnlySpan<byte> password, ReadOnlySpan<byte> salt, int iterations, HashAlgorithmName hashAlgorithm)
{
if (iterations <= 0) if (iterations <= 0)
throw new ArgumentOutOfRangeException(nameof(iterations), SR.ArgumentOutOfRange_NeedPosNum); throw new ArgumentOutOfRangeException(nameof(iterations), SR.ArgumentOutOfRange_NeedPosNum);
if (password is null)
throw new NullReferenceException(); // This "should" be ArgumentNullException but for compat, we throw NullReferenceException.
_salt = new byte[salt.Length + sizeof(uint)]; _salt = new byte[salt.Length + sizeof(uint)];
salt.AsSpan().CopyTo(_salt); salt.CopyTo(_salt);
_iterations = (uint)iterations; _iterations = (uint)iterations;
HashAlgorithm = hashAlgorithm; HashAlgorithm = hashAlgorithm;
_hmac = OpenHmac(password); _hmac = OpenHmac(password);
if (clearPassword) _blockSize = _hmac.HashLengthInBytes;
{
CryptographicOperations.ZeroMemory(password);
}
// _blockSize is in bytes, HashSize is in bits.
_blockSize = _hmac.HashSize >> 3;
Initialize(); Initialize();
} }
...@@ -167,27 +169,35 @@ protected override void Dispose(bool disposing) ...@@ -167,27 +169,35 @@ protected override void Dispose(bool disposing)
public override byte[] GetBytes(int cb) public override byte[] GetBytes(int cb)
{ {
Debug.Assert(_blockSize > 0);
if (cb <= 0) if (cb <= 0)
throw new ArgumentOutOfRangeException(nameof(cb), SR.ArgumentOutOfRange_NeedPosNum); throw new ArgumentOutOfRangeException(nameof(cb), SR.ArgumentOutOfRange_NeedPosNum);
byte[] password = new byte[cb];
byte[] ret = new byte[cb];
GetBytes(ret);
return ret;
}
internal void GetBytes(Span<byte> destination)
{
Debug.Assert(_blockSize > 0);
int cb = destination.Length;
int offset = 0; int offset = 0;
int size = _endIndex - _startIndex; int size = _endIndex - _startIndex;
ReadOnlySpan<byte> bufferSpan = _buffer;
if (size > 0) if (size > 0)
{ {
if (cb >= size) if (cb >= size)
{ {
Buffer.BlockCopy(_buffer, _startIndex, password, 0, size); bufferSpan.Slice(_startIndex, size).CopyTo(destination);
_startIndex = _endIndex = 0; _startIndex = _endIndex = 0;
offset += size; offset += size;
} }
else else
{ {
Buffer.BlockCopy(_buffer, _startIndex, password, 0, cb); bufferSpan.Slice(_startIndex, cb).CopyTo(destination);
_startIndex += cb; _startIndex += cb;
return password; return;
} }
} }
...@@ -199,18 +209,17 @@ public override byte[] GetBytes(int cb) ...@@ -199,18 +209,17 @@ public override byte[] GetBytes(int cb)
int remainder = cb - offset; int remainder = cb - offset;
if (remainder >= _blockSize) if (remainder >= _blockSize)
{ {
Buffer.BlockCopy(_buffer, 0, password, offset, _blockSize); bufferSpan.Slice(0, _blockSize).CopyTo(destination.Slice(offset));
offset += _blockSize; offset += _blockSize;
} }
else else
{ {
Buffer.BlockCopy(_buffer, 0, password, offset, remainder); bufferSpan.Slice(0, remainder).CopyTo(destination.Slice(offset));
_startIndex = remainder; _startIndex = remainder;
_endIndex = _buffer.Length; _endIndex = _buffer.Length;
return password; return;
} }
} }
return password;
} }
[Obsolete(Obsoletions.Rfc2898CryptDeriveKeyMessage, DiagnosticId = Obsoletions.Rfc2898CryptDeriveKeyDiagId, UrlFormat = Obsoletions.SharedUrlFormat)] [Obsolete(Obsoletions.Rfc2898CryptDeriveKeyMessage, DiagnosticId = Obsoletions.Rfc2898CryptDeriveKeyDiagId, UrlFormat = Obsoletions.SharedUrlFormat)]
...@@ -230,26 +239,25 @@ public override void Reset() ...@@ -230,26 +239,25 @@ public override void Reset()
Initialize(); Initialize();
} }
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Security", "CA5350", Justification = "HMACSHA1 is needed for compat. (https://github.com/dotnet/runtime/issues/17618)")] private IncrementalHash OpenHmac(ReadOnlySpan<byte> password)
private HMAC OpenHmac(byte[] password)
{ {
Debug.Assert(password != null);
HashAlgorithmName hashAlgorithm = HashAlgorithm; HashAlgorithmName hashAlgorithm = HashAlgorithm;
if (string.IsNullOrEmpty(hashAlgorithm.Name)) if (string.IsNullOrEmpty(hashAlgorithm.Name))
{
throw new CryptographicException(SR.Cryptography_HashAlgorithmNameNullOrEmpty); throw new CryptographicException(SR.Cryptography_HashAlgorithmNameNullOrEmpty);
}
if (hashAlgorithm == HashAlgorithmName.SHA1) // Restrict the HashAlgorithmName to known hashes, particularly excluding MD5.
return new HMACSHA1(password); if (hashAlgorithm != HashAlgorithmName.SHA1 &&
if (hashAlgorithm == HashAlgorithmName.SHA256) hashAlgorithm != HashAlgorithmName.SHA256 &&
return new HMACSHA256(password); hashAlgorithm != HashAlgorithmName.SHA384 &&
if (hashAlgorithm == HashAlgorithmName.SHA384) hashAlgorithm != HashAlgorithmName.SHA512)
return new HMACSHA384(password); {
if (hashAlgorithm == HashAlgorithmName.SHA512) throw new CryptographicException(SR.Format(SR.Cryptography_UnknownHashAlgorithm, hashAlgorithm.Name));
return new HMACSHA512(password); }
throw new CryptographicException(SR.Format(SR.Cryptography_UnknownHashAlgorithm, hashAlgorithm.Name)); return IncrementalHash.CreateHMAC(hashAlgorithm, password);
} }
[MemberNotNull(nameof(_buffer))] [MemberNotNull(nameof(_buffer))]
...@@ -281,20 +289,17 @@ private void Func() ...@@ -281,20 +289,17 @@ private void Func()
// //
Span<byte> uiSpan = stackalloc byte[64]; Span<byte> uiSpan = stackalloc byte[64];
uiSpan = uiSpan.Slice(0, _blockSize); uiSpan = uiSpan.Slice(0, _blockSize);
_hmac.AppendData(_salt);
if (!_hmac.TryComputeHash(_salt, uiSpan, out int bytesWritten) || bytesWritten != _blockSize) int bytesWritten = _hmac.GetHashAndReset(uiSpan);
{ Debug.Assert(bytesWritten == _blockSize);
throw new CryptographicException();
}
uiSpan.CopyTo(_buffer); uiSpan.CopyTo(_buffer);
for (int i = 2; i <= _iterations; i++) for (int i = 2; i <= _iterations; i++)
{ {
if (!_hmac.TryComputeHash(uiSpan, uiSpan, out bytesWritten) || bytesWritten != _blockSize) _hmac.AppendData(uiSpan);
{ bytesWritten = _hmac.GetHashAndReset(uiSpan);
throw new CryptographicException(); Debug.Assert(bytesWritten == _blockSize);
}
for (int j = _buffer.Length - 1; j >= 0; j--) for (int j = _buffer.Length - 1; j >= 0; j--)
{ {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册