未验证 提交 e49e320e 编写于 作者: A Anatoly Sergeev 提交者: GitHub

Scripting API (eval, evalSha, scriptExists, scriptLoad) (#269)

上级 c655e7b6
......@@ -5,15 +5,24 @@ import java.util.concurrent.TimeUnit
import zio.Chunk
import zio.duration.Duration
import zio.redis.RespValue.BulkString
import zio.schema.Schema
import zio.schema.codec.Codec
sealed trait Input[-A] {
self =>
private[redis] def encode(data: A)(implicit codec: Codec): Chunk[RespValue.BulkString]
final def contramap[B](f: B => A): Input[B] = new Input[B] {
def encode(data: B)(implicit codec: Codec): Chunk[BulkString] = self.encode(f(data))
}
}
object Input {
def apply[A](implicit input: Input[A]): Input[A] = input
@inline
private[this] def encodeString(s: String): RespValue.BulkString = RespValue.bulkString(s)
@inline
......@@ -557,6 +566,21 @@ object Input {
data.foldLeft(Chunk.empty: Chunk[RespValue.BulkString])((acc, a) => acc ++ input.encode(a))
}
final case class EvalInput[-K, -V](inputK: Input[K], inputV: Input[V]) extends Input[(String, Chunk[K], Chunk[V])] {
def encode(data: (String, Chunk[K], Chunk[V]))(implicit codec: Codec): Chunk[RespValue.BulkString] = {
val (lua, keys, args) = data
val encodedScript = Chunk(encodeString(lua), encodeString(keys.size.toString))
val encodedKeys = keys.flatMap(inputK.encode)
val encodedArgs = args.flatMap(inputV.encode)
encodedScript ++ encodedKeys ++ encodedArgs
}
}
case object ScriptDebugInput extends Input[DebugMode] {
def encode(data: DebugMode)(implicit codec: Codec): Chunk[RespValue.BulkString] =
Chunk.single(encodeString(data.stringify))
}
case object WithScoresInput extends Input[WithScores] {
def encode(data: WithScores)(implicit codec: Codec): Chunk[RespValue.BulkString] =
Chunk.single(encodeString(data.stringify))
......
......@@ -18,6 +18,10 @@ sealed trait Output[+A] {
throw RedisError.BusyGroup(msg.drop(9).trim)
case RespValue.Error(msg) if msg.startsWith("NOGROUP") =>
throw RedisError.NoGroup(msg.drop(7).trim)
case RespValue.Error(msg) if msg.startsWith("NOSCRIPT") =>
throw RedisError.NoScript(msg.drop(8).trim)
case RespValue.Error(msg) if msg.startsWith("NOTBUSY") =>
throw RedisError.NotBusy(msg.drop(7).trim)
case RespValue.Error(msg) =>
throw RedisError.ProtocolError(msg.trim)
case success =>
......@@ -37,6 +41,12 @@ object Output {
import RedisError._
def apply[A](implicit output: Output[A]): Output[A] = output
case object RespValueOutput extends Output[RespValue] {
protected def tryDecode(respValue: RespValue)(implicit codec: Codec): RespValue = respValue
}
case object BoolOutput extends Output[Boolean] {
protected def tryDecode(respValue: RespValue)(implicit codec: Codec): Boolean =
respValue match {
......
......@@ -16,5 +16,7 @@ object RedisError {
final case class WrongType(message: String) extends RedisError
final case class BusyGroup(message: String) extends RedisError
final case class NoGroup(message: String) extends RedisError
final case class NoScript(message: String) extends RedisError
final case class NotBusy(message: String) extends RedisError
final case class IOError(exception: IOException) extends RedisError
}
......@@ -27,4 +27,8 @@ object ResultBuilder {
trait ResultBuilder3[+F[_, _, _]] extends ResultBuilder {
def returning[R1: Schema, R2: Schema, R3: Schema]: ZIO[RedisExecutor, RedisError, F[R1, R2, R3]]
}
trait ResultOutputBuilder extends ResultBuilder {
def returning[R: Output]: ZIO[RedisExecutor, RedisError, R]
}
}
package zio.redis.api
import zio.{Chunk, ZIO}
import zio.redis._
import zio.redis.Input._
import zio.redis.Output._
import zio.redis.ResultBuilder.ResultOutputBuilder
trait Scripting {
import Scripting._
/**
* Evaluates a Lua script
*
* Example of custom data:
* {{{
* final case class Person(name: String, age: Long)
*
* val encoder = Input.Tuple2(StringInput, LongInput).contramap[Person] { civ =>
* (civ.name, civ.age)
* }
* val decoder = RespValueOutput.map {
* case RespValue.Array(elements) =>
* val name = elements(0) match {
* case s @ RespValue.BulkString(_) => s.asString
* case other => throw ProtocolError(s"$other isn't a string type")
* }
* val age = elements(1) match {
* case RespValue.Integer(value) => value
* case other => throw ProtocolError(s"$other isn't a integer type")
* }
* Person(name, age)
* case other => throw ProtocolError(s"$other isn't an array type")
* }
* }}}
*
* @param script
* Lua script
* @param keys
* keys available through KEYS param in the script
* @param args
* values available through ARGV param in the script
* @return
* redis protocol value that is converted from the Lua type. You have to write decoder that would convert redis
* protocol value to a suitable type for your app
*/
def eval[K: Input, A: Input](
script: String,
keys: Chunk[K],
args: Chunk[A]
): ResultOutputBuilder = new ResultOutputBuilder {
def returning[R: Output]: ZIO[RedisExecutor, RedisError, R] = {
val command = RedisCommand(Eval, EvalInput(Input[K], Input[A]), Output[R])
command.run((script, keys, args))
}
}
/**
* Evaluates a Lua script cached on the server side by its SHA1 digest. Scripts could be cached using the
* [[zio.redis.api.Scripting.scriptLoad]] method.
*
* @param sha1
* SHA1 digest
* @param keys
* keys available through KEYS param in the script
* @param args
* values available through ARGV param in the script
* @return
* redis protocol value that is converted from the Lua type. You have to write decoder that would convert redis
* protocol value to a suitable type for your app
*/
def evalSha[K: Input, A: Input](
sha1: String,
keys: Chunk[K],
args: Chunk[A]
): ResultOutputBuilder = new ResultOutputBuilder {
def returning[R: Output]: ZIO[RedisExecutor, RedisError, R] = {
val command = RedisCommand(EvalSha, EvalInput(Input[K], Input[A]), Output[R])
command.run((sha1, keys, args))
}
}
/**
* Checks existence of the scripts in the script cache.
*
* @param sha1
* one required SHA1 digest
* @param sha1s
* maybe rest of the SHA1 digests
* @return
* for every corresponding SHA1 digest of a script that actually exists in the script cache, an true is returned,
* otherwise false is returned.
*/
def scriptExists(sha1: String, sha1s: String*): ZIO[RedisExecutor, RedisError, Chunk[Boolean]] = {
val command = RedisCommand(ScriptExists, NonEmptyList(StringInput), ChunkOutput(BoolOutput))
command.run((sha1, sha1s.toList))
}
/**
* Loads a script into the scripts cache. After the script is loaded into the script cache it could be evaluated using
* the [[zio.redis.api.Scripting.evalSha]] method.
*
* @param script
* Lua script
* @return
* the SHA1 digest of the script added into the script cache.
*/
def scriptLoad(script: String): ZIO[RedisExecutor, RedisError, String] = {
val command = RedisCommand(ScriptLoad, StringInput, MultiStringOutput)
command.run(script)
}
}
private[redis] object Scripting {
final val Eval = "EVAL"
final val EvalSha = "EVALSHA"
final val ScriptExists = "SCRIPT EXISTS"
final val ScriptLoad = "SCRIPT LOAD"
}
package zio.redis.options
trait Scripting {
sealed trait DebugMode { self =>
private[redis] final def stringify: String =
self match {
case DebugMode.Yes => "YES"
case DebugMode.Sync => "SYNC"
case DebugMode.No => "NO"
}
}
object DebugMode {
case object Yes extends DebugMode
case object Sync extends DebugMode
case object No extends DebugMode
}
}
......@@ -11,6 +11,7 @@ package object redis
with api.Strings
with api.SortedSets
with api.Streams
with api.Scripting
with options.Connection
with options.Geo
with options.Keys
......@@ -18,7 +19,8 @@ package object redis
with options.SortedSets
with options.Strings
with options.Lists
with options.Streams {
with options.Streams
with options.Scripting {
type Id[+A] = A
type RedisExecutor = Has[RedisExecutor.Service]
......
......@@ -17,7 +17,8 @@ object ApiSpec
with GeoSpec
with HyperLogLogSpec
with HashSpec
with StreamsSpec {
with StreamsSpec
with ScriptingSpec {
def spec: ZSpec[TestEnvironment, Failure] =
suite("Redis commands")(
......@@ -31,7 +32,8 @@ object ApiSpec
geoSuite,
hyperLogLogSuite,
hashSuite,
streamsSuite
streamsSuite,
scriptingSpec
).provideCustomLayerShared((Logging.ignore ++ ZLayer.succeed(codec) >>> RedisExecutor.local.orDie) ++ Clock.live)
@@ sequential,
suite("Test Executor")(
......
......@@ -42,13 +42,9 @@ object OutputSpec extends BaseSpec {
} yield assert(res)(isEmpty)
},
testM("extract non-empty arrays") {
val respValue = RespValue.array(RespValue.bulkString("foo"), RespValue.bulkString("bar"))
for {
res <-
Task(
ChunkOutput(MultiStringOutput).unsafeDecode(
RespValue.array(RespValue.bulkString("foo"), RespValue.bulkString("bar"))
)
)
res <- Task(ChunkOutput(MultiStringOutput).unsafeDecode(respValue))
} yield assert(res)(hasSameElements(Chunk("foo", "bar")))
}
),
......
package zio.redis
import zio._
import zio.redis.Input.{BoolInput, ByteInput, LongInput, StringInput}
import zio.redis.Output._
import zio.redis.RedisError._
import zio.redis.ScriptingSpec._
import zio.test._
import zio.test.Assertion._
import scala.util.Random
trait ScriptingSpec extends BaseSpec {
val scriptingSpec: Spec[Annotations with RedisExecutor, TestFailure[Any], TestSuccess] =
suite("scripting")(
suite("eval")(
testM("put boolean and return existence of key") {
for {
key <- uuid
arg = true
lua =
"""
|redis.call('set',KEYS[1],ARGV[1])
|return redis.call('exists',KEYS[1])
""".stripMargin
res <- eval(lua, Chunk(key), Chunk(arg)).returning[Boolean]
} yield assertTrue(res)
},
testM("take strings return strings") {
for {
key1 <- uuid
key2 <- uuid
arg1 <- uuid
arg2 <- uuid
lua = """return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}"""
res <- eval(lua, Chunk(key1, key2), Chunk(arg1, arg2)).returning[Chunk[String]]
} yield assertTrue(res == Chunk(key1, key2, arg1, arg2))
},
testM("put custom input value return custom input value") {
for {
key1 <- uuid
key2 <- uuid
arg1 <- uuid
arg2 <- ZIO.succeedNow(Random.nextLong())
arg = CustomInputValue(arg1, arg2)
lua = """return {ARGV[1],ARGV[2]}"""
res <- eval(lua, Chunk(key1, key2), Chunk(arg)).returning[Map[String, String]]
} yield assertTrue(res == Map(arg1 -> arg2.toString))
},
testM("return custom data type") {
val lua = """return {1,2,{3,'Hello World!'}}"""
val expected = CustomData(1, 2, (3, "Hello World!"))
val emptyInput: Chunk[Long] = Chunk.empty
for {
res <- eval(lua, emptyInput, emptyInput).returning[CustomData]
} yield assertTrue(res == expected)
},
testM("throw an error when incorrect script's sent") {
for {
key <- uuid
arg <- uuid
lua = ";"
error = "Error compiling script (new function): user_script:1: unexpected symbol near ';'"
res <- eval(lua, Chunk(key), Chunk(arg)).returning[String].either
} yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(error)))))
},
testM("throw an error if couldn't decode resp value") {
val customError = "custom error"
implicit val decoder: Output[String] = errorOutput(customError)
for {
key <- uuid
arg <- uuid
lua = ""
res <- eval(lua, Chunk(key), Chunk(arg)).returning[String](decoder).either
} yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(customError)))))
},
testM("throw custom error from script") {
for {
key <- uuid
arg <- uuid
myError = "My Error"
lua = s"""return redis.error_reply("${myError}")"""
res <- eval(lua, Chunk(key), Chunk(arg)).returning[String].either
} yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(myError)))))
}
),
suite("evalSHA")(
testM("put boolean and return existence of key") {
for {
key <- uuid
arg = true
lua =
"""
|redis.call('set',KEYS[1],ARGV[1])
|return redis.call('exists',KEYS[1])
""".stripMargin
sha <- scriptLoad(lua)
res <- evalSha(sha, Chunk(key), Chunk(arg)).returning[Boolean]
} yield assertTrue(res)
},
testM("take strings return strings") {
for {
key1 <- uuid
key2 <- uuid
arg1 <- uuid
arg2 <- uuid
lua = """return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}"""
sha <- scriptLoad(lua)
res <- evalSha(sha, Chunk(key1, key2), Chunk(arg1, arg2)).returning[Chunk[String]]
} yield assertTrue(res == Chunk(key1, key2, arg1, arg2))
},
testM("return custom data type") {
val lua = """return {1,2,{3,'Hello World!'}}"""
val expected = CustomData(1, 2, (3, "Hello World!"))
val emptyInput: Chunk[String] = Chunk.empty
for {
res <- eval(lua, emptyInput, emptyInput).returning[CustomData]
} yield assertTrue(res == expected)
},
testM("throw an error if couldn't decode resp value") {
val customError = "custom error"
implicit val decoder: Output[String] = errorOutput(customError)
for {
key <- uuid
arg <- uuid
lua = ""
sha <- scriptLoad(lua)
res <- evalSha(sha, Chunk(key), Chunk(arg)).returning[String](decoder).either
} yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(customError)))))
},
testM("throw custom error from script") {
for {
key <- uuid
arg <- uuid
myError = "My Error"
lua = s"""return redis.error_reply("${myError}")"""
sha <- scriptLoad(lua)
res <- evalSha(sha, Chunk(key), Chunk(arg)).returning[String].either
} yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(myError)))))
},
testM("throw NoScript error if script isn't found in cache") {
val lua = """return "1""""
val error = "No matching script. Please use EVAL."
val emptyInput: Chunk[String] = Chunk.empty
for {
res <- evalSha(lua, emptyInput, emptyInput).returning[String].either
} yield assert(res)(isLeft(isSubtype[NoScript](hasField("message", _.message, equalTo(error)))))
}
),
suite("scriptExists")(
testM("return true if scripts are found in the cache") {
val lua1 = """return "1""""
val lua2 = """return "2""""
for {
sha1 <- scriptLoad(lua1)
sha2 <- scriptLoad(lua2)
res <- scriptExists(sha1, sha2)
} yield assertTrue(res == Chunk(true, true))
},
testM("return false if scripts aren't found in the cache") {
val lua1 = """return "1""""
val lua2 = """return "2""""
for {
res <- scriptExists(lua1, lua2)
} yield assertTrue(res == Chunk(false, false))
}
),
suite("scriptLoad")(
testM("return OK") {
val lua = """return "1""""
for {
sha <- scriptLoad(lua)
} yield assert(sha)(isSubtype[String](anything))
},
testM("throw an error when incorrect script was sent") {
val lua = ";"
val error = "Error compiling script (new function): user_script:1: unexpected symbol near ';'"
for {
sha <- scriptLoad(lua).either
} yield assert(sha)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(error)))))
}
)
)
}
object ScriptingSpec {
final case class CustomInputValue(name: String, age: Long)
object CustomInputValue {
implicit val encoder: Input[CustomInputValue] = Input.Tuple2(StringInput, LongInput).contramap { civ =>
(civ.name, civ.age)
}
}
final case class CustomData(count: Long, avg: Long, pair: (Int, String))
object CustomData {
implicit val decoder: Output[CustomData] = RespValueOutput.map {
case RespValue.Array(elements) =>
val count = RespToLong(elements(0))
val avg = RespToLong(elements(1))
val pair = elements(2) match {
case RespValue.Array(elements) => (RespToLong(elements(0)).toInt, RespToString(elements(1)))
case other => throw ProtocolError(s"$other isn't an array type")
}
CustomData(count, avg, pair)
case other => throw ProtocolError(s"$other isn't an array type")
}
}
private val RespToLong: RespValue => Long = {
case RespValue.Integer(value) => value
case other => throw ProtocolError(s"$other isn't a integer type")
}
private val RespToString: RespValue => String = {
case s @ RespValue.BulkString(_) => s.asString
case other => throw ProtocolError(s"$other isn't a string type")
}
implicit val bytesEncoder: Input[Chunk[Byte]] = ByteInput
implicit val booleanInput: Input[Boolean] = BoolInput
implicit val stringInput: Input[String] = StringInput
implicit val longInput: Input[Long] = LongInput
implicit val keyValueOutput: KeyValueOutput[String, String] = KeyValueOutput(MultiStringOutput, MultiStringOutput)
implicit val booleanOutput: Output[Boolean] = RespValueOutput.map {
case RespValue.Integer(0) => false
case RespValue.Integer(1) => true
case other => throw ProtocolError(s"$other isn't a string nor an array")
}
implicit val simpleStringOutput: Output[String] = RespValueOutput.map {
case RespValue.SimpleString(value) => value
case other => throw ProtocolError(s"$other isn't a string nor an array")
}
implicit val chunkStringOutput: Output[Chunk[String]] = RespValueOutput.map {
case RespValue.Array(elements) =>
elements.map {
case s @ RespValue.BulkString(_) => s.asString
case other => throw ProtocolError(s"$other isn't a bulk string")
}
case other => throw ProtocolError(s"$other isn't a string nor an array")
}
def errorOutput(error: String): Output[String] = RespValueOutput.map { _ =>
throw ProtocolError(error)
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册