未验证 提交 266abfb3 编写于 作者: D Dejan Mijić 提交者: GitHub

Remote UTF8 decoder (#796)

上级 a993b0b5
......@@ -22,6 +22,8 @@ import zio.redis.options.Cluster.{Node, Partition, SlotRange}
import zio.schema.Schema
import zio.schema.codec.BinaryCodec
import java.nio.charset.StandardCharsets
sealed trait Output[+A] { self =>
protected def tryDecode(respValue: RespValue): A
......@@ -718,7 +720,7 @@ object Output {
}
private def decodeDouble(bytes: Chunk[Byte]): Double = {
val text = RespValue.decode(bytes)
val text = new String(bytes.toArray, StandardCharsets.UTF_8)
try text.toDouble
catch {
case _: NumberFormatException => throw ProtocolError(s"'$text' isn't a double.")
......
......@@ -52,7 +52,7 @@ final class SingleNodeExecutor private (
while (it.hasNext) {
val req = it.next()
buffer ++= RespValue.Array(req.command).serialize
buffer ++= RespValue.Array(req.command).asBytes
}
val bytes = buffer.result()
......
......@@ -27,7 +27,7 @@ private[redis] sealed trait RespValue extends Product with Serializable { self =
import RespValue._
import RespValue.internal.{CrLf, Headers, NullArrayEncoded, NullStringEncoded}
final def serialize: Chunk[Byte] =
final def asBytes: Chunk[Byte] =
self match {
case NullBulkString => NullStringEncoded
case NullArray => NullArrayEncoded
......@@ -39,7 +39,7 @@ private[redis] sealed trait RespValue extends Product with Serializable { self =
Headers.BulkString +: (encode(bytes.length.toString) ++ bytes ++ CrLf)
case Array(elements) =>
val data = elements.foldLeft[Chunk[Byte]](Chunk.empty)(_ ++ _.serialize)
val data = elements.foldLeft[Chunk[Byte]](Chunk.empty)(_ ++ _.asBytes)
Headers.Array +: (encode(elements.size.toString) ++ data)
}
......@@ -72,9 +72,9 @@ private[redis] object RespValue {
final case class Integer(value: Long) extends RespValue
final case class BulkString(value: Chunk[Byte]) extends RespValue {
def asLong: Long = internal.unsafeReadLong(asString, 0)
def asLong: Long = internal.unsafeReadLong(value, 0)
def asString: String = decode(value)
def asString: String = internal.decode(value)
}
final case class Array(values: Chunk[RespValue]) extends RespValue
......@@ -96,24 +96,20 @@ private[redis] object RespValue {
// ZSink fold will return a State.Start when contFn is false
val lineProcessor =
ZSink.fold[String, State](State.Start)(_.inProgress)(_ feed _).mapZIO {
ZSink.foldChunks[Byte, State](State.Start)(_.inProgress)(_ feed _).mapZIO {
case State.Done(value) => ZIO.succeed(Some(value))
case State.Failed => ZIO.fail(RedisError.ProtocolError("Invalid data received."))
case State.Start => ZIO.succeed(None)
case other => ZIO.dieMessage(s"Deserialization bug, should not get $other")
}
(ZPipeline.utf8Decode >>> ZPipeline.splitOn(internal.CrLfString))
.mapError(e => RedisError.ProtocolError(e.getLocalizedMessage))
.andThen(ZPipeline.fromSink(lineProcessor))
ZPipeline.splitOnChunk(internal.CrLf) >>> ZPipeline.fromSink(lineProcessor)
}
def array(values: RespValue*): Array = Array(Chunk.fromIterable(values))
def bulkString(s: String): BulkString = BulkString(Chunk.fromArray(s.getBytes(StandardCharsets.UTF_8)))
def decode(bytes: Chunk[Byte]): String = new String(bytes.toArray, StandardCharsets.UTF_8)
private object internal {
object Headers {
final val SimpleString: Byte = '+'
......@@ -124,11 +120,10 @@ private[redis] object RespValue {
}
final val CrLf: Chunk[Byte] = Chunk('\r', '\n')
final val CrLfString: String = "\r\n"
final val NullArrayEncoded: Chunk[Byte] = Chunk.fromArray("*-1\r\n".getBytes(StandardCharsets.US_ASCII))
final val NullArrayPrefix: String = "*-1"
final val NullStringEncoded: Chunk[Byte] = Chunk.fromArray("$-1\r\n".getBytes(StandardCharsets.US_ASCII))
final val NullStringPrefix: String = "$-1"
final val NullArrayEncoded: Chunk[Byte] = Chunk('*', '-', '1', '\r', '\n')
final val NullArrayPrefix: Chunk[Byte] = Chunk('*', '-', '1')
final val NullStringEncoded: Chunk[Byte] = Chunk('$', '-', '1', '\r', '\n')
final val NullStringPrefix: Chunk[Byte] = Chunk('$', '-', '1')
sealed trait State { self =>
import State._
......@@ -139,22 +134,22 @@ private[redis] object RespValue {
case _ => true
}
final def feed(line: String): State =
final def feed(bytes: Chunk[Byte]): State =
self match {
case Start if line.isEmpty() => Start
case Start if line == NullStringPrefix => Done(NullBulkString)
case Start if line == NullArrayPrefix => Done(NullArray)
case Start if line.nonEmpty =>
line.head match {
case Headers.SimpleString => Done(SimpleString(line.tail))
case Headers.Error => Done(Error(line.tail))
case Headers.Integer => Done(Integer(unsafeReadLong(line, 1)))
case Start if bytes.isEmpty => Start
case Start if bytes == NullStringPrefix => Done(NullBulkString)
case Start if bytes == NullArrayPrefix => Done(NullArray)
case Start if bytes.nonEmpty =>
bytes.head match {
case Headers.SimpleString => Done(SimpleString(decode(bytes.tail)))
case Headers.Error => Done(Error(decode(bytes.tail)))
case Headers.Integer => Done(Integer(unsafeReadLong(bytes, 1)))
case Headers.BulkString =>
val size = unsafeReadLong(line, 1).toInt
CollectingBulkString(size, new StringBuilder(size))
val size = unsafeReadSize(bytes)
CollectingBulkString(size, ChunkBuilder.make(size))
case Headers.Array =>
val size = unsafeReadLong(line, 1).toInt
val size = unsafeReadSize(bytes)
if (size > 0)
CollectingArray(size, ChunkBuilder.make(size), Start.feed)
......@@ -165,18 +160,20 @@ private[redis] object RespValue {
}
case CollectingArray(rem, vals, next) =>
next(line) match {
next(bytes) match {
case Done(v) if rem > 1 => CollectingArray(rem - 1, vals += v, Start.feed)
case Done(v) => Done(Array((vals += v).result()))
case state => CollectingArray(rem, vals, state.feed)
}
case CollectingBulkString(rem, vals) =>
if (line.length >= rem) {
val stringValue = vals.append(line.substring(0, rem)).toString
Done(BulkString(Chunk.fromArray(stringValue.getBytes(StandardCharsets.UTF_8))))
if (bytes.length >= rem) {
vals ++= bytes.take(rem)
Done(BulkString(vals.result()))
} else {
CollectingBulkString(rem - line.length - 2, vals.append(line).append(CrLfString))
vals ++= bytes
vals ++= CrLf
CollectingBulkString(rem - bytes.length - 2, vals)
}
case _ => Failed
......@@ -184,31 +181,50 @@ private[redis] object RespValue {
}
object State {
case object Start extends State
case object Failed extends State
final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: String => State) extends State
final case class CollectingBulkString(rem: Int, vals: StringBuilder) extends State
final case class Done(value: RespValue) extends State
case object Start extends State
case object Failed extends State
final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: Chunk[Byte] => State)
extends State
final case class CollectingBulkString(rem: Int, vals: ChunkBuilder[Byte]) extends State
final case class Done(value: RespValue) extends State
}
def unsafeReadLong(text: String, startFrom: Int): Long = {
def decode(bytes: Chunk[Byte]): String = new String(bytes.toArray, StandardCharsets.UTF_8)
def unsafeReadLong(bytes: Chunk[Byte], startFrom: Int): Long = {
var pos = startFrom
var res = 0L
var neg = false
if (text.charAt(pos) == '-') {
if (bytes(pos) == '-') {
neg = true
pos += 1
}
val len = text.length
val len = bytes.length
while (pos < len) {
res = res * 10 + text.charAt(pos) - '0'
res = res * 10 + bytes(pos) - '0'
pos += 1
}
if (neg) -res else res
}
def unsafeReadSize(bytes: Chunk[Byte]): Int = {
var pos = 1
var res = 0
val len = bytes.length
while (pos < len) {
res = res * 10 + bytes(pos) - '0'
pos += 1
}
res
}
}
}
......@@ -2,47 +2,35 @@ package zio.redis.internal
import zio.Chunk
import zio.redis._
import zio.stream.ZStream
import zio.test.Assertion._
import zio.test._
import java.nio.charset.StandardCharsets
object RespValueSpec extends BaseSpec {
def spec: Spec[Any, RedisError.ProtocolError] =
suite("RespValue")(
suite("serialization")(
test("array") {
val expected = Chunk.fromArray("*3\r\n$3\r\nabc\r\n:123\r\n$-1\r\n".getBytes(StandardCharsets.UTF_8))
val v = RespValue.array(RespValue.bulkString("abc"), RespValue.Integer(123), RespValue.NullBulkString)
assert(v.serialize)(equalTo(expected))
}
),
suite("deserialization")(
test("array") {
val values = Chunk(
RespValue.SimpleString("OK"),
test("serializes and deserializes messages") {
val values = Chunk(
RespValue.SimpleString("OK"),
RespValue.bulkString("test1"),
RespValue.array(
RespValue.bulkString("test1"),
RespValue.array(
RespValue.bulkString("test1"),
RespValue.Integer(42L),
RespValue.NullBulkString,
RespValue.array(RespValue.SimpleString("a"), RespValue.Integer(0L)),
RespValue.bulkString("in array"),
RespValue.SimpleString("test2")
),
RespValue.NullBulkString
)
RespValue.Integer(42L),
RespValue.NullBulkString,
RespValue.array(RespValue.SimpleString("a"), RespValue.Integer(0L)),
RespValue.bulkString("in array"),
RespValue.SimpleString("test2")
),
RespValue.NullBulkString
)
zio.stream.ZStream
.fromChunk(values)
.mapConcat(_.serialize)
.via(RespValue.Decoder)
.collect { case Some(value) =>
value
}
.runCollect
.map(assert(_)(equalTo(values)))
}
)
ZStream
.fromChunk(values)
.mapConcat(_.asBytes)
.via(RespValue.Decoder)
.collectSome
.runCollect
.map(assert(_)(equalTo(values)))
}
)
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册