From e49e320e3b2d0b37a8763745fa8f333566fabceb Mon Sep 17 00:00:00 2001 From: Anatoly Sergeev Date: Thu, 6 Jan 2022 15:31:49 +0300 Subject: [PATCH] Scripting API (eval, evalSha, scriptExists, scriptLoad) (#269) --- redis/src/main/scala/zio/redis/Input.scala | 24 ++ redis/src/main/scala/zio/redis/Output.scala | 10 + .../src/main/scala/zio/redis/RedisError.scala | 2 + .../main/scala/zio/redis/ResultBuilder.scala | 4 + .../main/scala/zio/redis/api/Scripting.scala | 119 +++++++++ .../scala/zio/redis/options/Scripting.scala | 18 ++ redis/src/main/scala/zio/redis/package.scala | 4 +- redis/src/test/scala/zio/redis/ApiSpec.scala | 6 +- .../src/test/scala/zio/redis/OutputSpec.scala | 8 +- .../test/scala/zio/redis/ScriptingSpec.scala | 248 ++++++++++++++++++ 10 files changed, 434 insertions(+), 9 deletions(-) create mode 100644 redis/src/main/scala/zio/redis/api/Scripting.scala create mode 100644 redis/src/main/scala/zio/redis/options/Scripting.scala create mode 100644 redis/src/test/scala/zio/redis/ScriptingSpec.scala diff --git a/redis/src/main/scala/zio/redis/Input.scala b/redis/src/main/scala/zio/redis/Input.scala index 05a7e77..cc4fe2c 100644 --- a/redis/src/main/scala/zio/redis/Input.scala +++ b/redis/src/main/scala/zio/redis/Input.scala @@ -5,15 +5,24 @@ import java.util.concurrent.TimeUnit import zio.Chunk import zio.duration.Duration +import zio.redis.RespValue.BulkString import zio.schema.Schema import zio.schema.codec.Codec sealed trait Input[-A] { + self => + private[redis] def encode(data: A)(implicit codec: Codec): Chunk[RespValue.BulkString] + + final def contramap[B](f: B => A): Input[B] = new Input[B] { + def encode(data: B)(implicit codec: Codec): Chunk[BulkString] = self.encode(f(data)) + } } object Input { + def apply[A](implicit input: Input[A]): Input[A] = input + @inline private[this] def encodeString(s: String): RespValue.BulkString = RespValue.bulkString(s) @inline @@ -557,6 +566,21 @@ object Input { data.foldLeft(Chunk.empty: Chunk[RespValue.BulkString])((acc, a) => acc ++ input.encode(a)) } + final case class EvalInput[-K, -V](inputK: Input[K], inputV: Input[V]) extends Input[(String, Chunk[K], Chunk[V])] { + def encode(data: (String, Chunk[K], Chunk[V]))(implicit codec: Codec): Chunk[RespValue.BulkString] = { + val (lua, keys, args) = data + val encodedScript = Chunk(encodeString(lua), encodeString(keys.size.toString)) + val encodedKeys = keys.flatMap(inputK.encode) + val encodedArgs = args.flatMap(inputV.encode) + encodedScript ++ encodedKeys ++ encodedArgs + } + } + + case object ScriptDebugInput extends Input[DebugMode] { + def encode(data: DebugMode)(implicit codec: Codec): Chunk[RespValue.BulkString] = + Chunk.single(encodeString(data.stringify)) + } + case object WithScoresInput extends Input[WithScores] { def encode(data: WithScores)(implicit codec: Codec): Chunk[RespValue.BulkString] = Chunk.single(encodeString(data.stringify)) diff --git a/redis/src/main/scala/zio/redis/Output.scala b/redis/src/main/scala/zio/redis/Output.scala index 78d12d2..c332ce2 100644 --- a/redis/src/main/scala/zio/redis/Output.scala +++ b/redis/src/main/scala/zio/redis/Output.scala @@ -18,6 +18,10 @@ sealed trait Output[+A] { throw RedisError.BusyGroup(msg.drop(9).trim) case RespValue.Error(msg) if msg.startsWith("NOGROUP") => throw RedisError.NoGroup(msg.drop(7).trim) + case RespValue.Error(msg) if msg.startsWith("NOSCRIPT") => + throw RedisError.NoScript(msg.drop(8).trim) + case RespValue.Error(msg) if msg.startsWith("NOTBUSY") => + throw RedisError.NotBusy(msg.drop(7).trim) case RespValue.Error(msg) => throw RedisError.ProtocolError(msg.trim) case success => @@ -37,6 +41,12 @@ object Output { import RedisError._ + def apply[A](implicit output: Output[A]): Output[A] = output + + case object RespValueOutput extends Output[RespValue] { + protected def tryDecode(respValue: RespValue)(implicit codec: Codec): RespValue = respValue + } + case object BoolOutput extends Output[Boolean] { protected def tryDecode(respValue: RespValue)(implicit codec: Codec): Boolean = respValue match { diff --git a/redis/src/main/scala/zio/redis/RedisError.scala b/redis/src/main/scala/zio/redis/RedisError.scala index 5427af9..e22cac3 100644 --- a/redis/src/main/scala/zio/redis/RedisError.scala +++ b/redis/src/main/scala/zio/redis/RedisError.scala @@ -16,5 +16,7 @@ object RedisError { final case class WrongType(message: String) extends RedisError final case class BusyGroup(message: String) extends RedisError final case class NoGroup(message: String) extends RedisError + final case class NoScript(message: String) extends RedisError + final case class NotBusy(message: String) extends RedisError final case class IOError(exception: IOException) extends RedisError } diff --git a/redis/src/main/scala/zio/redis/ResultBuilder.scala b/redis/src/main/scala/zio/redis/ResultBuilder.scala index b90e043..254ffd1 100644 --- a/redis/src/main/scala/zio/redis/ResultBuilder.scala +++ b/redis/src/main/scala/zio/redis/ResultBuilder.scala @@ -27,4 +27,8 @@ object ResultBuilder { trait ResultBuilder3[+F[_, _, _]] extends ResultBuilder { def returning[R1: Schema, R2: Schema, R3: Schema]: ZIO[RedisExecutor, RedisError, F[R1, R2, R3]] } + + trait ResultOutputBuilder extends ResultBuilder { + def returning[R: Output]: ZIO[RedisExecutor, RedisError, R] + } } diff --git a/redis/src/main/scala/zio/redis/api/Scripting.scala b/redis/src/main/scala/zio/redis/api/Scripting.scala new file mode 100644 index 0000000..324b0f7 --- /dev/null +++ b/redis/src/main/scala/zio/redis/api/Scripting.scala @@ -0,0 +1,119 @@ +package zio.redis.api + +import zio.{Chunk, ZIO} +import zio.redis._ +import zio.redis.Input._ +import zio.redis.Output._ +import zio.redis.ResultBuilder.ResultOutputBuilder + +trait Scripting { + import Scripting._ + + /** + * Evaluates a Lua script + * + * Example of custom data: + * {{{ + * final case class Person(name: String, age: Long) + * + * val encoder = Input.Tuple2(StringInput, LongInput).contramap[Person] { civ => + * (civ.name, civ.age) + * } + * val decoder = RespValueOutput.map { + * case RespValue.Array(elements) => + * val name = elements(0) match { + * case s @ RespValue.BulkString(_) => s.asString + * case other => throw ProtocolError(s"$other isn't a string type") + * } + * val age = elements(1) match { + * case RespValue.Integer(value) => value + * case other => throw ProtocolError(s"$other isn't a integer type") + * } + * Person(name, age) + * case other => throw ProtocolError(s"$other isn't an array type") + * } + * }}} + * + * @param script + * Lua script + * @param keys + * keys available through KEYS param in the script + * @param args + * values available through ARGV param in the script + * @return + * redis protocol value that is converted from the Lua type. You have to write decoder that would convert redis + * protocol value to a suitable type for your app + */ + def eval[K: Input, A: Input]( + script: String, + keys: Chunk[K], + args: Chunk[A] + ): ResultOutputBuilder = new ResultOutputBuilder { + def returning[R: Output]: ZIO[RedisExecutor, RedisError, R] = { + val command = RedisCommand(Eval, EvalInput(Input[K], Input[A]), Output[R]) + command.run((script, keys, args)) + } + } + + /** + * Evaluates a Lua script cached on the server side by its SHA1 digest. Scripts could be cached using the + * [[zio.redis.api.Scripting.scriptLoad]] method. + * + * @param sha1 + * SHA1 digest + * @param keys + * keys available through KEYS param in the script + * @param args + * values available through ARGV param in the script + * @return + * redis protocol value that is converted from the Lua type. You have to write decoder that would convert redis + * protocol value to a suitable type for your app + */ + def evalSha[K: Input, A: Input]( + sha1: String, + keys: Chunk[K], + args: Chunk[A] + ): ResultOutputBuilder = new ResultOutputBuilder { + def returning[R: Output]: ZIO[RedisExecutor, RedisError, R] = { + val command = RedisCommand(EvalSha, EvalInput(Input[K], Input[A]), Output[R]) + command.run((sha1, keys, args)) + } + } + + /** + * Checks existence of the scripts in the script cache. + * + * @param sha1 + * one required SHA1 digest + * @param sha1s + * maybe rest of the SHA1 digests + * @return + * for every corresponding SHA1 digest of a script that actually exists in the script cache, an true is returned, + * otherwise false is returned. + */ + def scriptExists(sha1: String, sha1s: String*): ZIO[RedisExecutor, RedisError, Chunk[Boolean]] = { + val command = RedisCommand(ScriptExists, NonEmptyList(StringInput), ChunkOutput(BoolOutput)) + command.run((sha1, sha1s.toList)) + } + + /** + * Loads a script into the scripts cache. After the script is loaded into the script cache it could be evaluated using + * the [[zio.redis.api.Scripting.evalSha]] method. + * + * @param script + * Lua script + * @return + * the SHA1 digest of the script added into the script cache. + */ + def scriptLoad(script: String): ZIO[RedisExecutor, RedisError, String] = { + val command = RedisCommand(ScriptLoad, StringInput, MultiStringOutput) + command.run(script) + } +} + +private[redis] object Scripting { + final val Eval = "EVAL" + final val EvalSha = "EVALSHA" + final val ScriptExists = "SCRIPT EXISTS" + final val ScriptLoad = "SCRIPT LOAD" +} diff --git a/redis/src/main/scala/zio/redis/options/Scripting.scala b/redis/src/main/scala/zio/redis/options/Scripting.scala new file mode 100644 index 0000000..b710e6a --- /dev/null +++ b/redis/src/main/scala/zio/redis/options/Scripting.scala @@ -0,0 +1,18 @@ +package zio.redis.options + +trait Scripting { + sealed trait DebugMode { self => + private[redis] final def stringify: String = + self match { + case DebugMode.Yes => "YES" + case DebugMode.Sync => "SYNC" + case DebugMode.No => "NO" + } + } + + object DebugMode { + case object Yes extends DebugMode + case object Sync extends DebugMode + case object No extends DebugMode + } +} diff --git a/redis/src/main/scala/zio/redis/package.scala b/redis/src/main/scala/zio/redis/package.scala index 9829271..37a21c7 100644 --- a/redis/src/main/scala/zio/redis/package.scala +++ b/redis/src/main/scala/zio/redis/package.scala @@ -11,6 +11,7 @@ package object redis with api.Strings with api.SortedSets with api.Streams + with api.Scripting with options.Connection with options.Geo with options.Keys @@ -18,7 +19,8 @@ package object redis with options.SortedSets with options.Strings with options.Lists - with options.Streams { + with options.Streams + with options.Scripting { type Id[+A] = A type RedisExecutor = Has[RedisExecutor.Service] diff --git a/redis/src/test/scala/zio/redis/ApiSpec.scala b/redis/src/test/scala/zio/redis/ApiSpec.scala index 5077d39..f54fdd5 100644 --- a/redis/src/test/scala/zio/redis/ApiSpec.scala +++ b/redis/src/test/scala/zio/redis/ApiSpec.scala @@ -17,7 +17,8 @@ object ApiSpec with GeoSpec with HyperLogLogSpec with HashSpec - with StreamsSpec { + with StreamsSpec + with ScriptingSpec { def spec: ZSpec[TestEnvironment, Failure] = suite("Redis commands")( @@ -31,7 +32,8 @@ object ApiSpec geoSuite, hyperLogLogSuite, hashSuite, - streamsSuite + streamsSuite, + scriptingSpec ).provideCustomLayerShared((Logging.ignore ++ ZLayer.succeed(codec) >>> RedisExecutor.local.orDie) ++ Clock.live) @@ sequential, suite("Test Executor")( diff --git a/redis/src/test/scala/zio/redis/OutputSpec.scala b/redis/src/test/scala/zio/redis/OutputSpec.scala index 6b2a0b2..3b7ab31 100644 --- a/redis/src/test/scala/zio/redis/OutputSpec.scala +++ b/redis/src/test/scala/zio/redis/OutputSpec.scala @@ -42,13 +42,9 @@ object OutputSpec extends BaseSpec { } yield assert(res)(isEmpty) }, testM("extract non-empty arrays") { + val respValue = RespValue.array(RespValue.bulkString("foo"), RespValue.bulkString("bar")) for { - res <- - Task( - ChunkOutput(MultiStringOutput).unsafeDecode( - RespValue.array(RespValue.bulkString("foo"), RespValue.bulkString("bar")) - ) - ) + res <- Task(ChunkOutput(MultiStringOutput).unsafeDecode(respValue)) } yield assert(res)(hasSameElements(Chunk("foo", "bar"))) } ), diff --git a/redis/src/test/scala/zio/redis/ScriptingSpec.scala b/redis/src/test/scala/zio/redis/ScriptingSpec.scala new file mode 100644 index 0000000..6591bf8 --- /dev/null +++ b/redis/src/test/scala/zio/redis/ScriptingSpec.scala @@ -0,0 +1,248 @@ +package zio.redis + +import zio._ +import zio.redis.Input.{BoolInput, ByteInput, LongInput, StringInput} +import zio.redis.Output._ +import zio.redis.RedisError._ +import zio.redis.ScriptingSpec._ +import zio.test._ +import zio.test.Assertion._ + +import scala.util.Random + +trait ScriptingSpec extends BaseSpec { + val scriptingSpec: Spec[Annotations with RedisExecutor, TestFailure[Any], TestSuccess] = + suite("scripting")( + suite("eval")( + testM("put boolean and return existence of key") { + for { + key <- uuid + arg = true + lua = + """ + |redis.call('set',KEYS[1],ARGV[1]) + |return redis.call('exists',KEYS[1]) + """.stripMargin + res <- eval(lua, Chunk(key), Chunk(arg)).returning[Boolean] + } yield assertTrue(res) + }, + testM("take strings return strings") { + for { + key1 <- uuid + key2 <- uuid + arg1 <- uuid + arg2 <- uuid + lua = """return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}""" + res <- eval(lua, Chunk(key1, key2), Chunk(arg1, arg2)).returning[Chunk[String]] + } yield assertTrue(res == Chunk(key1, key2, arg1, arg2)) + }, + testM("put custom input value return custom input value") { + for { + key1 <- uuid + key2 <- uuid + arg1 <- uuid + arg2 <- ZIO.succeedNow(Random.nextLong()) + arg = CustomInputValue(arg1, arg2) + lua = """return {ARGV[1],ARGV[2]}""" + res <- eval(lua, Chunk(key1, key2), Chunk(arg)).returning[Map[String, String]] + } yield assertTrue(res == Map(arg1 -> arg2.toString)) + }, + testM("return custom data type") { + val lua = """return {1,2,{3,'Hello World!'}}""" + val expected = CustomData(1, 2, (3, "Hello World!")) + val emptyInput: Chunk[Long] = Chunk.empty + for { + res <- eval(lua, emptyInput, emptyInput).returning[CustomData] + } yield assertTrue(res == expected) + }, + testM("throw an error when incorrect script's sent") { + for { + key <- uuid + arg <- uuid + lua = ";" + error = "Error compiling script (new function): user_script:1: unexpected symbol near ';'" + res <- eval(lua, Chunk(key), Chunk(arg)).returning[String].either + } yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(error))))) + }, + testM("throw an error if couldn't decode resp value") { + val customError = "custom error" + implicit val decoder: Output[String] = errorOutput(customError) + for { + key <- uuid + arg <- uuid + lua = "" + res <- eval(lua, Chunk(key), Chunk(arg)).returning[String](decoder).either + } yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(customError))))) + }, + testM("throw custom error from script") { + for { + key <- uuid + arg <- uuid + myError = "My Error" + lua = s"""return redis.error_reply("${myError}")""" + res <- eval(lua, Chunk(key), Chunk(arg)).returning[String].either + } yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(myError))))) + } + ), + suite("evalSHA")( + testM("put boolean and return existence of key") { + for { + key <- uuid + arg = true + lua = + """ + |redis.call('set',KEYS[1],ARGV[1]) + |return redis.call('exists',KEYS[1]) + """.stripMargin + sha <- scriptLoad(lua) + res <- evalSha(sha, Chunk(key), Chunk(arg)).returning[Boolean] + } yield assertTrue(res) + }, + testM("take strings return strings") { + for { + key1 <- uuid + key2 <- uuid + arg1 <- uuid + arg2 <- uuid + lua = """return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}""" + sha <- scriptLoad(lua) + res <- evalSha(sha, Chunk(key1, key2), Chunk(arg1, arg2)).returning[Chunk[String]] + } yield assertTrue(res == Chunk(key1, key2, arg1, arg2)) + }, + testM("return custom data type") { + val lua = """return {1,2,{3,'Hello World!'}}""" + val expected = CustomData(1, 2, (3, "Hello World!")) + val emptyInput: Chunk[String] = Chunk.empty + for { + res <- eval(lua, emptyInput, emptyInput).returning[CustomData] + } yield assertTrue(res == expected) + }, + testM("throw an error if couldn't decode resp value") { + val customError = "custom error" + implicit val decoder: Output[String] = errorOutput(customError) + for { + key <- uuid + arg <- uuid + lua = "" + sha <- scriptLoad(lua) + res <- evalSha(sha, Chunk(key), Chunk(arg)).returning[String](decoder).either + } yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(customError))))) + }, + testM("throw custom error from script") { + for { + key <- uuid + arg <- uuid + myError = "My Error" + lua = s"""return redis.error_reply("${myError}")""" + sha <- scriptLoad(lua) + res <- evalSha(sha, Chunk(key), Chunk(arg)).returning[String].either + } yield assert(res)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(myError))))) + }, + testM("throw NoScript error if script isn't found in cache") { + val lua = """return "1"""" + val error = "No matching script. Please use EVAL." + val emptyInput: Chunk[String] = Chunk.empty + for { + res <- evalSha(lua, emptyInput, emptyInput).returning[String].either + } yield assert(res)(isLeft(isSubtype[NoScript](hasField("message", _.message, equalTo(error))))) + } + ), + suite("scriptExists")( + testM("return true if scripts are found in the cache") { + val lua1 = """return "1"""" + val lua2 = """return "2"""" + for { + sha1 <- scriptLoad(lua1) + sha2 <- scriptLoad(lua2) + res <- scriptExists(sha1, sha2) + } yield assertTrue(res == Chunk(true, true)) + }, + testM("return false if scripts aren't found in the cache") { + val lua1 = """return "1"""" + val lua2 = """return "2"""" + for { + res <- scriptExists(lua1, lua2) + } yield assertTrue(res == Chunk(false, false)) + } + ), + suite("scriptLoad")( + testM("return OK") { + val lua = """return "1"""" + for { + sha <- scriptLoad(lua) + } yield assert(sha)(isSubtype[String](anything)) + }, + testM("throw an error when incorrect script was sent") { + val lua = ";" + val error = "Error compiling script (new function): user_script:1: unexpected symbol near ';'" + for { + sha <- scriptLoad(lua).either + } yield assert(sha)(isLeft(isSubtype[ProtocolError](hasField("message", _.message, equalTo(error))))) + } + ) + ) +} + +object ScriptingSpec { + + final case class CustomInputValue(name: String, age: Long) + + object CustomInputValue { + implicit val encoder: Input[CustomInputValue] = Input.Tuple2(StringInput, LongInput).contramap { civ => + (civ.name, civ.age) + } + } + + final case class CustomData(count: Long, avg: Long, pair: (Int, String)) + + object CustomData { + implicit val decoder: Output[CustomData] = RespValueOutput.map { + case RespValue.Array(elements) => + val count = RespToLong(elements(0)) + val avg = RespToLong(elements(1)) + val pair = elements(2) match { + case RespValue.Array(elements) => (RespToLong(elements(0)).toInt, RespToString(elements(1))) + case other => throw ProtocolError(s"$other isn't an array type") + } + CustomData(count, avg, pair) + case other => throw ProtocolError(s"$other isn't an array type") + } + } + + private val RespToLong: RespValue => Long = { + case RespValue.Integer(value) => value + case other => throw ProtocolError(s"$other isn't a integer type") + } + private val RespToString: RespValue => String = { + case s @ RespValue.BulkString(_) => s.asString + case other => throw ProtocolError(s"$other isn't a string type") + } + + implicit val bytesEncoder: Input[Chunk[Byte]] = ByteInput + implicit val booleanInput: Input[Boolean] = BoolInput + implicit val stringInput: Input[String] = StringInput + implicit val longInput: Input[Long] = LongInput + + implicit val keyValueOutput: KeyValueOutput[String, String] = KeyValueOutput(MultiStringOutput, MultiStringOutput) + implicit val booleanOutput: Output[Boolean] = RespValueOutput.map { + case RespValue.Integer(0) => false + case RespValue.Integer(1) => true + case other => throw ProtocolError(s"$other isn't a string nor an array") + } + implicit val simpleStringOutput: Output[String] = RespValueOutput.map { + case RespValue.SimpleString(value) => value + case other => throw ProtocolError(s"$other isn't a string nor an array") + } + implicit val chunkStringOutput: Output[Chunk[String]] = RespValueOutput.map { + case RespValue.Array(elements) => + elements.map { + case s @ RespValue.BulkString(_) => s.asString + case other => throw ProtocolError(s"$other isn't a bulk string") + } + case other => throw ProtocolError(s"$other isn't a string nor an array") + } + + def errorOutput(error: String): Output[String] = RespValueOutput.map { _ => + throw ProtocolError(error) + } +} -- GitLab