未验证 提交 ce380903 编写于 作者: D Dejan Mijić 提交者: GitHub

Improve array serialization (#265)

上级 e259e911
package zio.redis
import java.util.concurrent.TimeUnit
import org.openjdk.jmh.annotations._
import zio.ZIO
@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
@Measurement(iterations = 15)
@Warmup(iterations = 15)
@Fork(2)
class SMembersBenchmarks extends BenchmarkRuntime {
@Param(Array("500"))
private var size: Int = _
private var items: List[String] = _
private val key = "test-set"
@Setup(Level.Trial)
def setup(): Unit = {
items = (0 to size).toList.map(_.toString)
zioUnsafeRun(sAdd(key, items.head, items.tail: _*).unit)
}
@Benchmark
def laserdisc(): Unit = {
import _root_.laserdisc.fs2._
import _root_.laserdisc.{ all => cmd, _ }
import cats.instances.list._
import cats.syntax.foldable._
unsafeRun[LaserDiscClient](c => items.traverse_(_ => c.send(cmd.smembers(Key.unsafeFrom(key)))))
}
@Benchmark
def rediculous(): Unit = {
import cats.implicits._
import io.chrisdavenport.rediculous._
unsafeRun[RediculousClient](c => items.traverse_(_ => RedisCommands.smembers[RedisIO](key).run(c)))
}
@Benchmark
def redis4cats(): Unit = {
import cats.instances.list._
import cats.syntax.foldable._
unsafeRun[Redis4CatsClient[String]](c => items.traverse_(_ => c.sMembers(key)))
}
@Benchmark
def zio(): Unit = zioUnsafeRun(ZIO.foreach_(items)(_ => sMembers(key)))
}
......@@ -36,14 +36,6 @@ object Output {
import RedisError._
private def decodeDouble(bytes: Chunk[Byte]): Double = {
val text = RespValue.decodeString(bytes)
try text.toDouble
catch {
case _: NumberFormatException => throw ProtocolError(s"'$text' isn't a double.")
}
}
case object BoolOutput extends Output[Boolean] {
protected def tryDecode(respValue: RespValue): Boolean =
respValue match {
......@@ -402,4 +394,12 @@ object Output {
case other => throw ProtocolError(s"$other isn't a valid set response")
}
}
private def decodeDouble(bytes: Chunk[Byte]): Double = {
val text = RespValue.decode(bytes)
try text.toDouble
catch {
case _: NumberFormatException => throw ProtocolError(s"'$text' isn't a double.")
}
}
}
......@@ -82,7 +82,7 @@ object RedisExecutor {
private def receive: IO[RedisError, Unit] =
byteStream.read
.mapError(RedisError.IOError)
.transduce(RespValue.Deserializer)
.transduce(RespValue.Decoder)
.foreach(response => resQueue.take.flatMap(_.succeed(response)))
}
......
......@@ -36,7 +36,7 @@ object RespValue {
final case class Integer(value: Long) extends RespValue
final case class BulkString(value: Chunk[Byte]) extends RespValue {
private[redis] def asString: String = decodeString(value)
private[redis] def asString: String = decode(value)
private[redis] def asLong: Long = internal.unsafeReadLong(asString, 0)
}
......@@ -53,17 +53,7 @@ object RespValue {
}
}
def array(values: RespValue*): Array = Array(Chunk.fromIterable(values))
def bulkString(s: String): BulkString = BulkString(Chunk.fromArray(s.getBytes(StandardCharsets.UTF_8)))
def decodeString(bytes: Chunk[Byte]): String = new String(bytes.toArray, StandardCharsets.UTF_8)
private[redis] final val Cr: Byte = '\r'
private[redis] final val Lf: Byte = '\n'
private[redis] final val Deserializer: Transducer[RedisError.ProtocolError, Byte, RespValue] = {
private[redis] final val Decoder: Transducer[RedisError.ProtocolError, Byte, RespValue] = {
import internal.State
val processLine =
......@@ -78,6 +68,12 @@ object RespValue {
Transducer.utf8Decode >>> Transducer.splitLines >>> processLine
}
private[redis] def array(values: RespValue*): Array = Array(Chunk.fromIterable(values))
private[redis] def bulkString(s: String): BulkString = BulkString(Chunk.fromArray(s.getBytes(StandardCharsets.UTF_8)))
private[redis] def decode(bytes: Chunk[Byte]): String = new String(bytes.toArray, StandardCharsets.UTF_8)
private object internal {
object Headers {
final val SimpleString: Byte = '+'
......@@ -87,7 +83,7 @@ object RespValue {
final val Array: Byte = '*'
}
final val CrLf: Chunk[Byte] = Chunk(Cr, Lf)
final val CrLf: Chunk[Byte] = Chunk('\r', '\n')
final val NullArray: String = "*-1"
final val NullValue: String = "$-1"
final val NullString: Chunk[Byte] = Chunk.fromArray("$-1\r\n".getBytes(StandardCharsets.US_ASCII))
......@@ -115,15 +111,15 @@ object RespValue {
val size = unsafeReadLong(line, 1).toInt
if (size > 0)
CollectingArray(size, Chunk.empty, Start.feed)
CollectingArray(size, ChunkBuilder.make(size), Start.feed)
else
Done(Array(Chunk.empty))
}
case CollectingArray(rem, vals, next) =>
next(line) match {
case Done(v) if rem > 1 => CollectingArray(rem - 1, vals :+ v, Start.feed)
case Done(v) => Done(Array(vals :+ v))
case Done(v) if rem > 1 => CollectingArray(rem - 1, vals += v, Start.feed)
case Done(v) => Done(Array((vals += v).result()))
case state => CollectingArray(rem, vals, state.feed)
}
......@@ -133,11 +129,11 @@ object RespValue {
}
object State {
case object Start extends State
case object ExpectingBulk extends State
case object Failed extends State
final case class CollectingArray(rem: Int, vals: Chunk[RespValue], next: String => State) extends State
final case class Done(value: RespValue) extends State
case object Start extends State
case object ExpectingBulk extends State
case object Failed extends State
final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: String => State) extends State
final case class Done(value: RespValue) extends State
}
def unsafeReadLong(text: String, startFrom: Int): Long = {
......
......@@ -35,7 +35,7 @@ object RespValueSpec extends BaseSpec {
Stream
.fromChunk(values)
.mapConcat(_.serialize)
.transduce(RespValue.Deserializer)
.transduce(RespValue.Decoder)
.runCollect
.map(assert(_)(equalTo(values)))
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册