/*
 * Copyright 2021 John A. De Goes and the ZIO contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package zio.redis.internal

import zio._
import zio.redis.options.Cluster.Slot
import zio.redis.{RedisError, RedisUri}
import zio.stream._

import java.nio.charset.StandardCharsets

/**
 * A value of the Redis serialization protocol (RESP): simple string, error, integer, bulk string, array, or one of
 * the two null markers.
 */
private[redis] sealed trait RespValue extends Product with Serializable { self =>
  import RespValue._
  import RespValue.internal.{CrLf, Headers, NullArrayEncoded, NullStringEncoded}

  /**
   * Serializes this value to its wire form: a one-byte type header, the payload, and a trailing CRLF. Bulk strings and
   * arrays additionally carry their length/element count on the header line; array elements are serialized
   * recursively.
   */
  final def asBytes: Chunk[Byte] =
    self match {
      case NullBulkString => NullStringEncoded

      case NullArray => NullArrayEncoded

      case SimpleString(s) =>
        val builder = new ChunkBuilder.Byte()
        builder += Headers.SimpleString
        builder ++= encode(s)
        builder ++= CrLf
        builder.result()

      case Error(s) =>
        val builder = new ChunkBuilder.Byte()
        builder += Headers.Error
        builder ++= encode(s)
        builder ++= CrLf
        builder.result()

      case Integer(i) =>
        val builder = new ChunkBuilder.Byte()
        builder += Headers.Integer
        builder ++= encode(i.toString)
        builder ++= CrLf
        builder.result()

      case BulkString(bytes) =>
        val builder = new ChunkBuilder.Byte()
        builder += Headers.BulkString
        builder ++= encode(bytes.length.toString)
        builder ++= CrLf
        builder ++= bytes
        builder ++= CrLf
        builder.result()

      case Array(elements) =>
        val builder = new ChunkBuilder.Byte()
        builder += Headers.Array
        builder ++= encode(elements.size.toString)
        builder ++= CrLf
        elements.foreach(builder ++= _.asBytes)
        builder.result()
    }

  // UTF-8 keeps this in step with `internal.decode` and `RespValue.bulkString`, so non-ASCII simple strings and
  // errors survive a round trip. For pure ASCII input (including every numeric field) the bytes are unchanged.
  private[this] def encode(s: String) = s.getBytes(StandardCharsets.UTF_8)
}

private[redis] object RespValue {

  /** A simple string value (header `+`). */
  final case class SimpleString(value: String) extends RespValue

  /** An error value (header `-`) carrying the raw error text. */
  final case class Error(value: String) extends RespValue {

    /**
     * Converts the raw error text into a [[RedisError]] by inspecting its prefix. The recognised prefix is dropped and
     * the remainder trimmed; `ASK` and `MOVED` are parsed into a slot/address pair instead. Text with no recognised
     * prefix becomes a `ProtocolError` carrying the trimmed message.
     */
    def asRedisError: RedisError =
      if (value.startsWith("ERR")) RedisError.ProtocolError(value.drop(3).trim)
      else if (value.startsWith("WRONGTYPE")) RedisError.WrongType(value.drop(9).trim)
      else if (value.startsWith("BUSYGROUP")) RedisError.BusyGroup(value.drop(9).trim)
      else if (value.startsWith("NOGROUP")) RedisError.NoGroup(value.drop(7).trim)
      else if (value.startsWith("NOSCRIPT")) RedisError.NoScript(value.drop(8).trim)
      else if (value.startsWith("NOTBUSY")) RedisError.NotBusy(value.drop(7).trim)
      else if (value.startsWith("CROSSSLOT")) RedisError.CrossSlot(value.drop(9).trim)
      else if (value.startsWith("ASK")) RedisError.Ask(parseRedirectError(value))
      else if (value.startsWith("MOVED")) RedisError.Moved(parseRedirectError(value))
      else RedisError.ProtocolError(value.trim)

    // Assumes the shape "<KIND> <slot> <address>" separated by single spaces: the second token is read as the
    // slot number and the third as the address. A message with fewer tokens or a non-numeric slot throws here.
    private def parseRedirectError(value: String) = {
      val splittingError = value.split(' ')
      (Slot(splittingError(1).toLong), RedisUri(splittingError(2)))
    }
  }

  /** An integer value (header `:`). */
  final case class Integer(value: Long) extends RespValue

  /** A length-prefixed binary payload (header `$`). */
  final case class BulkString(value: Chunk[Byte]) extends RespValue {

    /** Reads the payload as a decimal number with an optional leading `-`; the content is not validated. */
    def asLong: Long = internal.unsafeReadLong(value, 0)

    /** Decodes the payload as UTF-8 text. */
    def asString: String = internal.decode(value)
  }

  /** A sequence of nested values (header `*`). */
  final case class Array(values: Chunk[RespValue]) extends RespValue

  /** The null bulk string, encoded as `$-1`. */
  case object NullBulkString extends RespValue

  /** The null array, encoded as `*-1`. */
  case object NullArray extends RespValue

  /** Extractor that lets an [[Array]] be matched element-wise, e.g. `case ArrayValues(a, b) =>`. */
  object ArrayValues {
    def unapplySeq(v: RespValue): Option[Seq[RespValue]] =
      v match {
        case Array(values) => Some(values)
        case _             => None
      }
  }

  /**
   * Pipeline that turns a byte stream into decoded values. The input is split on CRLF and each line is fed to a state
   * machine ([[internal.State]]) until it reaches `Done` or `Failed`. A finished value is emitted as `Some`, a failed
   * parse fails the pipeline with a `ProtocolError`, and a sink run that ends still in `Start` emits `None`. Ending in
   * any other state is treated as a defect.
   */
  final val Decoder: ZPipeline[Any, RedisError.ProtocolError, Byte, Option[RespValue]] = {
    import internal.State

    // ZSink fold will return a State.Start when contFn is false
    val lineProcessor =
      ZSink.foldChunks[Byte, State](State.Start)(_.inProgress)(_ feed _).mapZIO {
        case State.Done(value) => ZIO.some(value)
        case State.Failed      => ZIO.fail(RedisError.ProtocolError("Invalid data received."))
        case State.Start       => ZIO.none
        case other             => ZIO.dieMessage(s"Deserialization bug, should not get $other")
      }

    ZPipeline.splitOnChunk(internal.CrLf) >>> ZPipeline.fromSink(lineProcessor)
  }

  /** Builds an [[Array]] from the given elements. */
  def array(values: RespValue*): Array = Array(Chunk.fromIterable(values))

  /** Builds a [[BulkString]] holding the UTF-8 bytes of `s`. */
  def bulkString(s: String): BulkString = BulkString(Chunk.fromArray(s.getBytes(StandardCharsets.UTF_8)))

  private object internal {

    /** The one-byte type markers that start each encoded value. */
    object Headers {
      final val SimpleString: Byte = '+'
      final val Error: Byte        = '-'
      final val Integer: Byte      = ':'
      final val BulkString: Byte   = '$'
      final val Array: Byte        = '*'
    }

    final val CrLf: Chunk[Byte]              = Chunk('\r', '\n')
    final val NullArrayEncoded: Chunk[Byte]  = Chunk('*', '-', '1', '\r', '\n')
    final val NullArrayPrefix: Chunk[Byte]   = Chunk('*', '-', '1')
    final val NullStringEncoded: Chunk[Byte] = Chunk('$', '-', '1', '\r', '\n')
    final val NullStringPrefix: Chunk[Byte]  = Chunk('$', '-', '1')

    /** Incremental parser state; each call to `feed` consumes one CRLF-delimited line. */
    sealed trait State { self =>
      import State._

      /** `false` once parsing has finished, either successfully (`Done`) or not (`Failed`). */
      final def inProgress: Boolean =
        self match {
          case Done(_) | Failed => false
          case _                => true
        }

      /**
       * Advances the state machine with the next line (CRLF already stripped by the splitter).
       *
       * @param bytes
       *   the content of one line, without its terminating CRLF
       * @return
       *   the next state; feeding a `Done` or `Failed` state yields `Failed`
       */
      final def feed(bytes: Chunk[Byte]): State =
        self match {
          case Start if bytes.isEmpty             => Start
          case Start if bytes == NullStringPrefix => Done(NullBulkString)
          case Start if bytes == NullArrayPrefix  => Done(NullArray)

          case Start if bytes.nonEmpty =>
            bytes.head match {
              case Headers.SimpleString => Done(SimpleString(decode(bytes.tail)))
              case Headers.Error        => Done(Error(decode(bytes.tail)))
              case Headers.Integer      => Done(Integer(unsafeReadLong(bytes, 1)))
              case Headers.BulkString =>
                val size = unsafeReadSize(bytes)
                CollectingBulkString(size, ChunkBuilder.make(size))
              case Headers.Array =>
                val size = unsafeReadSize(bytes)
                if (size > 0)
                  CollectingArray(size, ChunkBuilder.make(size), Start.feed)
                else
                  Done(Array(Chunk.empty))
              case _ => Failed
            }

          case CollectingArray(rem, vals, next) =>
            next(bytes) match {
              case Done(v) if rem > 1 => CollectingArray(rem - 1, vals += v, Start.feed)
              case Done(v)            => Done(Array((vals += v).result()))
              // A failed element must fail the whole array. Wrapping it in another CollectingArray would keep
              // `inProgress` true forever and the protocol error would never be reported.
              case Failed => Failed
              case state  => CollectingArray(rem, vals, state.feed)
            }

          case CollectingBulkString(rem, vals) =>
            if (bytes.length >= rem) {
              vals ++= bytes.take(rem)
              Done(BulkString(vals.result()))
            } else {
              // The line ended before the declared length was reached, so the CRLF removed by the splitter was
              // part of the payload: put it back and count its two bytes against the remainder.
              vals ++= bytes
              vals ++= CrLf
              CollectingBulkString(rem - bytes.length - 2, vals)
            }

          case _ => Failed
        }
    }

    object State {
      case object Start  extends State
      case object Failed extends State

      /**
       * Collecting the elements of an array: `rem` elements are still missing, `vals` accumulates the finished ones
       * and `next` is the feed function of the element currently being parsed.
       */
      final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: Chunk[Byte] => State)
          extends State

      /** Collecting a bulk string payload: `rem` bytes are still missing, `vals` accumulates those seen so far. */
      final case class CollectingBulkString(rem: Int, vals: ChunkBuilder[Byte]) extends State

      final case class Done(value: RespValue) extends State
    }

    /** Decodes `bytes` as UTF-8 text. */
    def decode(bytes: Chunk[Byte]): String = new String(bytes.toArray, StandardCharsets.UTF_8)

    /**
     * Parses a decimal number with an optional leading `-` from `bytes`, starting at index `startFrom` and running to
     * the end of the chunk. No validation is performed: non-digit bytes yield a meaningless result, and `startFrom`
     * must be a valid index.
     */
    def unsafeReadLong(bytes: Chunk[Byte], startFrom: Int): Long = {
      var pos = startFrom
      var res = 0L
      var neg = false

      if (bytes(pos) == '-') {
        neg = true
        pos += 1
      }

      val len = bytes.length

      while (pos < len) {
        res = res * 10 + bytes(pos) - '0'
        pos += 1
      }

      if (neg) -res else res
    }

    /**
     * Parses the unsigned decimal length that follows the header byte at index 0. No validation is performed, so
     * non-digit bytes yield a meaningless result.
     */
    def unsafeReadSize(bytes: Chunk[Byte]): Int = {
      var pos = 1 // skip the type header
      var res = 0
      val len = bytes.length

      while (pos < len) {
        res = res * 10 + bytes(pos) - '0'
        pos += 1
      }

      res
    }
  }
}