RespValue.scala 8.5 KB
Newer Older
D
Dejan Mijić 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/*
 * Copyright 2021 John A. De Goes and the ZIO contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

17
package zio.redis.internal
18

19
import zio._
A
Anatoly Sergeev 已提交
20
import zio.redis.options.Cluster.Slot
21
import zio.redis.{RedisError, RedisUri}
22
import zio.stream._
23

D
Dejan Mijić 已提交
24 25
import java.nio.charset.StandardCharsets

26
private[redis] sealed trait RespValue extends Product with Serializable { self =>
27
  import RespValue._
梦境迷离's avatar
梦境迷离 已提交
28
  import RespValue.internal.{CrLf, Headers, NullArrayEncoded, NullStringEncoded}
29

D
Dejan Mijić 已提交
30
  final def asBytes: Chunk[Byte] =
31
    self match {
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
      case NullBulkString => NullStringEncoded
      case NullArray      => NullArrayEncoded
      case SimpleString(s) =>
        val builder = new ChunkBuilder.Byte()

        builder += Headers.SimpleString
        builder ++= encode(s)
        builder ++= CrLf

        builder.result()

      case Error(s) =>
        val builder = new ChunkBuilder.Byte()

        builder += Headers.Error
        builder ++= encode(s)
        builder ++= CrLf

        builder.result()

      case Integer(i) =>
        val builder = new ChunkBuilder.Byte()

        builder += Headers.Integer
        builder ++= encode(i.toString)
        builder ++= CrLf

        builder.result()
60

61
      case BulkString(bytes) =>
62 63 64 65 66 67 68 69 70
        val builder = new ChunkBuilder.Byte()

        builder += Headers.BulkString
        builder ++= encode(bytes.length.toString)
        builder ++= CrLf
        builder ++= bytes
        builder ++= CrLf

        builder.result()
71

D
Dejan Mijić 已提交
72
      case Array(elements) =>
73 74 75 76 77 78 79 80 81
        val builder = new ChunkBuilder.Byte()

        builder += Headers.Array
        builder ++= encode(elements.size.toString)
        builder ++= CrLf

        elements.foreach(builder ++= _.asBytes)

        builder.result()
82 83
    }

84
  private[this] def encode(s: String) = s.getBytes(StandardCharsets.US_ASCII)
85 86
}

87
private[redis] object RespValue {
88
  final case class SimpleString(value: String) extends RespValue
89

A
Anatoly Sergeev 已提交
90
  final case class Error(value: String) extends RespValue {
D
Dejan Mijić 已提交
91
    def asRedisError: RedisError =
A
Anatoly Sergeev 已提交
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
      if (value.startsWith("ERR")) RedisError.ProtocolError(value.drop(3).trim)
      else if (value.startsWith("WRONGTYPE")) RedisError.WrongType(value.drop(9).trim)
      else if (value.startsWith("BUSYGROUP")) RedisError.BusyGroup(value.drop(9).trim)
      else if (value.startsWith("NOGROUP")) RedisError.NoGroup(value.drop(7).trim)
      else if (value.startsWith("NOSCRIPT")) RedisError.NoScript(value.drop(8).trim)
      else if (value.startsWith("NOTBUSY")) RedisError.NotBusy(value.drop(7).trim)
      else if (value.startsWith("CROSSSLOT")) RedisError.CrossSlot(value.drop(9).trim)
      else if (value.startsWith("ASK")) RedisError.Ask(parseRedirectError(value))
      else if (value.startsWith("MOVED")) RedisError.Moved(parseRedirectError(value))
      else RedisError.ProtocolError(value.trim)

    private def parseRedirectError(value: String) = {
      val splittingError = value.split(' ')
      (Slot(splittingError(1).toLong), RedisUri(splittingError(2)))
    }
  }
108

109
  final case class Integer(value: Long) extends RespValue
110

111
  final case class BulkString(value: Chunk[Byte]) extends RespValue {
D
Dejan Mijić 已提交
112
    def asLong: Long = internal.unsafeReadLong(value, 0)
113

D
Dejan Mijić 已提交
114
    def asString: String = internal.decode(value)
115 116
  }

117
  final case class Array(values: Chunk[RespValue]) extends RespValue
118

119 120 121
  case object NullBulkString extends RespValue

  case object NullArray extends RespValue
122

123 124 125 126 127 128 129 130
  object ArrayValues {
    def unapplySeq(v: RespValue): Option[Seq[RespValue]] =
      v match {
        case Array(values) => Some(values)
        case _             => None
      }
  }

131
  final val Decoder: ZPipeline[Any, RedisError.ProtocolError, Byte, Option[RespValue]] = {
D
Dejan Mijić 已提交
132 133 134 135
    import internal.State

    // ZSink fold will return a State.Start when contFn is false
    val lineProcessor =
D
Dejan Mijić 已提交
136
      ZSink.foldChunks[Byte, State](State.Start)(_.inProgress)(_ feed _).mapZIO {
137
        case State.Done(value) => ZIO.some(value)
梦境迷离's avatar
梦境迷离 已提交
138
        case State.Failed      => ZIO.fail(RedisError.ProtocolError("Invalid data received."))
139
        case State.Start       => ZIO.none
D
Dejan Mijić 已提交
140
        case other             => ZIO.dieMessage(s"Deserialization bug, should not get $other")
梦境迷离's avatar
梦境迷离 已提交
141 142
      }

D
Dejan Mijić 已提交
143
    ZPipeline.splitOnChunk(internal.CrLf) >>> ZPipeline.fromSink(lineProcessor)
144 145
  }

146
  def array(values: RespValue*): Array = Array(Chunk.fromIterable(values))
147

148
  def bulkString(s: String): BulkString = BulkString(Chunk.fromArray(s.getBytes(StandardCharsets.UTF_8)))
149

150 151 152 153 154 155 156
  private object internal {
    object Headers {
      final val SimpleString: Byte = '+'
      final val Error: Byte        = '-'
      final val Integer: Byte      = ':'
      final val BulkString: Byte   = '$'
      final val Array: Byte        = '*'
157 158
    }

159
    final val CrLf: Chunk[Byte]              = Chunk('\r', '\n')
D
Dejan Mijić 已提交
160 161 162 163
    final val NullArrayEncoded: Chunk[Byte]  = Chunk('*', '-', '1', '\r', '\n')
    final val NullArrayPrefix: Chunk[Byte]   = Chunk('*', '-', '1')
    final val NullStringEncoded: Chunk[Byte] = Chunk('$', '-', '1', '\r', '\n')
    final val NullStringPrefix: Chunk[Byte]  = Chunk('$', '-', '1')
164 165 166

    sealed trait State { self =>
      import State._
167

168 169 170 171 172
      final def inProgress: Boolean =
        self match {
          case Done(_) | Failed => false
          case _                => true
        }
173

D
Dejan Mijić 已提交
174
      final def feed(bytes: Chunk[Byte]): State =
175
        self match {
D
Dejan Mijić 已提交
176 177 178 179 180 181 182 183 184
          case Start if bytes.isEmpty             => Start
          case Start if bytes == NullStringPrefix => Done(NullBulkString)
          case Start if bytes == NullArrayPrefix  => Done(NullArray)

          case Start if bytes.nonEmpty =>
            bytes.head match {
              case Headers.SimpleString => Done(SimpleString(decode(bytes.tail)))
              case Headers.Error        => Done(Error(decode(bytes.tail)))
              case Headers.Integer      => Done(Integer(unsafeReadLong(bytes, 1)))
185
              case Headers.BulkString =>
D
Dejan Mijić 已提交
186 187
                val size = unsafeReadSize(bytes)
                CollectingBulkString(size, ChunkBuilder.make(size))
188
              case Headers.Array =>
D
Dejan Mijić 已提交
189
                val size = unsafeReadSize(bytes)
190 191

                if (size > 0)
192
                  CollectingArray(size, ChunkBuilder.make(size), Start.feed)
193 194
                else
                  Done(Array(Chunk.empty))
D
Dejan Mijić 已提交
195 196

              case _ => Failed
197 198 199
            }

          case CollectingArray(rem, vals, next) =>
D
Dejan Mijić 已提交
200
            next(bytes) match {
201 202
              case Done(v) if rem > 1 => CollectingArray(rem - 1, vals += v, Start.feed)
              case Done(v)            => Done(Array((vals += v).result()))
203 204 205
              case state              => CollectingArray(rem, vals, state.feed)
            }

206
          case CollectingBulkString(rem, vals) =>
D
Dejan Mijić 已提交
207 208 209
            if (bytes.length >= rem) {
              vals ++= bytes.take(rem)
              Done(BulkString(vals.result()))
210
            } else {
D
Dejan Mijić 已提交
211 212 213
              vals ++= bytes
              vals ++= CrLf
              CollectingBulkString(rem - bytes.length - 2, vals)
214 215 216
            }

          case _ => Failed
217
        }
218 219
    }

220
    object State {
D
Dejan Mijić 已提交
221 222 223 224 225 226 227 228 229
      case object Start  extends State
      case object Failed extends State

      final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: Chunk[Byte] => State)
          extends State

      final case class CollectingBulkString(rem: Int, vals: ChunkBuilder[Byte]) extends State

      final case class Done(value: RespValue) extends State
230 231
    }

D
Dejan Mijić 已提交
232 233 234
    def decode(bytes: Chunk[Byte]): String = new String(bytes.toArray, StandardCharsets.UTF_8)

    def unsafeReadLong(bytes: Chunk[Byte], startFrom: Int): Long = {
235
      var pos = startFrom
236 237 238
      var res = 0L
      var neg = false

D
Dejan Mijić 已提交
239
      if (bytes(pos) == '-') {
240 241
        neg = true
        pos += 1
242
      }
243

D
Dejan Mijić 已提交
244
      val len = bytes.length
245 246

      while (pos < len) {
D
Dejan Mijić 已提交
247
        res = res * 10 + bytes(pos) - '0'
248 249 250 251 252
        pos += 1
      }

      if (neg) -res else res
    }
D
Dejan Mijić 已提交
253 254 255 256 257 258 259 260 261 262 263 264 265

    def unsafeReadSize(bytes: Chunk[Byte]): Int = {
      var pos = 1
      var res = 0
      val len = bytes.length

      while (pos < len) {
        res = res * 10 + bytes(pos) - '0'
        pos += 1
      }

      res
    }
266 267
  }
}