未验证 提交 c9c85f11 编写于 作者: M marregui 提交者: GitHub

feat(sql): add SHOW commands for ACL (#3631)

上级 cdf19fc9
...@@ -109,6 +109,22 @@ public interface SecurityContext { ...@@ -109,6 +109,22 @@ public interface SecurityContext {
void authorizeSelect(TableToken tableToken, @NotNull ObjList<CharSequence> columnNames); void authorizeSelect(TableToken tableToken, @NotNull ObjList<CharSequence> columnNames);
void authorizeShowUser(CharSequence userName);
void authorizeShowUsers();
void authorizeShowGroups();
void authorizeShowGroups(CharSequence userName);
void authorizeShowServiceAccount(CharSequence serviceAccountName);
void authorizeShowServiceAccounts();
void authorizeShowServiceAccounts(CharSequence userOrGroupName);
void authorizeShowPermissions(CharSequence entityName);
void authorizeTableBackup(ObjHashSet<TableToken> tableTokens); void authorizeTableBackup(ObjHashSet<TableToken> tableTokens);
void authorizeTableCreate(); void authorizeTableCreate();
......
...@@ -182,6 +182,38 @@ public class AllowAllSecurityContext implements SecurityContext { ...@@ -182,6 +182,38 @@ public class AllowAllSecurityContext implements SecurityContext {
public void authorizeSelect(TableToken tableToken, @NotNull ObjList<CharSequence> columnNames) { public void authorizeSelect(TableToken tableToken, @NotNull ObjList<CharSequence> columnNames) {
} }
@Override
public void authorizeShowUser(CharSequence userName) {
}
@Override
public void authorizeShowUsers() {
}
@Override
public void authorizeShowGroups() {
}
@Override
public void authorizeShowGroups(CharSequence userName) {
}
@Override
public void authorizeShowServiceAccount(CharSequence serviceAccountName) {
}
@Override
public void authorizeShowServiceAccounts() {
}
@Override
public void authorizeShowServiceAccounts(CharSequence userOrGroupName) {
}
@Override
public void authorizeShowPermissions(CharSequence entityName) {
}
@Override @Override
public void authorizeTableBackup(ObjHashSet<TableToken> tableTokens) { public void authorizeTableBackup(ObjHashSet<TableToken> tableTokens) {
} }
......
...@@ -36,4 +36,44 @@ public class DenyAllSecurityContext extends ReadOnlySecurityContext { ...@@ -36,4 +36,44 @@ public class DenyAllSecurityContext extends ReadOnlySecurityContext {
public void authorizeSelect(TableToken tableToken, @NotNull ObjList<CharSequence> columnNames) { public void authorizeSelect(TableToken tableToken, @NotNull ObjList<CharSequence> columnNames) {
throw CairoException.nonCritical().put("permission denied"); throw CairoException.nonCritical().put("permission denied");
} }
@Override
public void authorizeShowUser(CharSequence userName) {
throw CairoException.nonCritical().put("permission denied");
}
@Override
public void authorizeShowUsers() {
throw CairoException.nonCritical().put("permission denied");
}
@Override
public void authorizeShowGroups() {
throw CairoException.nonCritical().put("permission denied");
}
@Override
public void authorizeShowGroups(CharSequence userName) {
throw CairoException.nonCritical().put("permission denied");
}
@Override
public void authorizeShowServiceAccount(CharSequence serviceAccountName) {
throw CairoException.nonCritical().put("permission denied");
}
@Override
public void authorizeShowServiceAccounts() {
throw CairoException.nonCritical().put("permission denied");
}
@Override
public void authorizeShowServiceAccounts(CharSequence userOrGroupName) {
throw CairoException.nonCritical().put("permission denied");
}
@Override
public void authorizeShowPermissions(CharSequence entityName) {
throw CairoException.nonCritical().put("permission denied");
}
} }
...@@ -218,6 +218,38 @@ public class ReadOnlySecurityContext implements SecurityContext { ...@@ -218,6 +218,38 @@ public class ReadOnlySecurityContext implements SecurityContext {
public void authorizeSelect(TableToken tableToken, @NotNull ObjList<CharSequence> columnNames) { public void authorizeSelect(TableToken tableToken, @NotNull ObjList<CharSequence> columnNames) {
} }
@Override
public void authorizeShowUser(CharSequence userName) {
}
@Override
public void authorizeShowUsers() {
}
@Override
public void authorizeShowGroups() {
}
@Override
public void authorizeShowGroups(CharSequence userName) {
}
@Override
public void authorizeShowServiceAccount(CharSequence serviceAccountName) {
}
@Override
public void authorizeShowServiceAccounts() {
}
@Override
public void authorizeShowServiceAccounts(CharSequence userOrGroupName) {
}
@Override
public void authorizeShowPermissions(CharSequence entityName) {
}
@Override @Override
public void authorizeTableBackup(ObjHashSet<TableToken> tableTokens) { public void authorizeTableBackup(ObjHashSet<TableToken> tableTokens) {
throw CairoException.authorization().put("Write permission denied").setCacheable(true); throw CairoException.authorization().put("Write permission denied").setCacheable(true);
......
/*******************************************************************************
* ___ _ ____ ____
* / _ \ _ _ ___ ___| |_| _ \| __ )
* | | | | | | |/ _ \/ __| __| | | | _ \
* | |_| | |_| | __/\__ \ |_| |_| | |_) |
* \__\_\\__,_|\___||___/\__|____/|____/
*
* Copyright (c) 2014-2019 Appsicle
* Copyright (c) 2019-2023 QuestDB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
package io.questdb.test.cairo;
import io.questdb.cairo.SecurityContext;
import io.questdb.cairo.TableToken;
import io.questdb.cairo.security.AllowAllSecurityContext;
import io.questdb.cairo.security.DenyAllSecurityContext;
import io.questdb.cairo.security.ReadOnlySecurityContext;
import io.questdb.std.LongList;
import io.questdb.std.ObjHashSet;
import io.questdb.std.ObjList;
import org.junit.Assert;
import org.junit.Test;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
public class SecurityContextTest {
private static final Object[] NO_PARAM_ARGS = {};
private static final ObjList<CharSequence> columns = new ObjList<>();
private static final LongList permissions = new LongList();
private final static String tableName = "tab";
private static final Object[] THREE_PARAM_ARGS = {permissions, tableName, columns};
private static final TableToken tableToken = new TableToken(tableName, tableName, 0, false);
private static final Object[] ONE_PARAM_ARGS = {tableToken};
private static final Object[] TWO_PARAM_ARGS = {tableToken, columns};
@Test
public void testAllowAllSecurityContext() throws InvocationTargetException, IllegalAccessException {
SecurityContext sc = AllowAllSecurityContext.INSTANCE;
for (Method method : SecurityContext.class.getMethods()) {
String name = method.getName();
if (name.startsWith("authorize")) {
Class<?>[] parameters = method.getParameterTypes();
switch (parameters.length) {
case 0:
method.invoke(sc, NO_PARAM_ARGS);
break;
case 1:
if (name.equals("authorizeCopyCancel")) {
method.invoke(sc, sc);
} else if (name.equals("authorizeTableBackup")) {
method.invoke(sc, new ObjHashSet<CharSequence>());
} else if (name.startsWith("authorizeShow")) {
method.invoke(sc, "userName");
} else {
method.invoke(sc, ONE_PARAM_ARGS);
}
break;
case 2:
method.invoke(sc, TWO_PARAM_ARGS);
break;
case 3:
method.invoke(sc, THREE_PARAM_ARGS);
break;
default:
throw new IndexOutOfBoundsException();
}
}
}
}
@Test
public void testDenyAllSecurityContext() throws IllegalAccessException {
SecurityContext sc = DenyAllSecurityContext.INSTANCE;
for (Method method : SecurityContext.class.getMethods()) {
String name = method.getName();
if (name.startsWith("authorize")) {
Class<?>[] parameters = method.getParameterTypes();
try {
switch (parameters.length) {
case 0:
method.invoke(sc, NO_PARAM_ARGS);
Assert.fail();
break;
case 1:
if (name.equals("authorizeCopyCancel")) {
method.invoke(sc, sc);
} else if (name.equals("authorizeTableBackup")) {
method.invoke(sc, new ObjHashSet<CharSequence>());
} else if (name.startsWith("authorizeShow")) {
method.invoke(sc, "userName");
} else {
method.invoke(sc, ONE_PARAM_ARGS);
}
Assert.fail();
break;
case 2:
method.invoke(sc, TWO_PARAM_ARGS);
Assert.fail();
break;
case 3:
method.invoke(sc, THREE_PARAM_ARGS);
Assert.fail();
default:
throw new IndexOutOfBoundsException();
}
} catch (InvocationTargetException err) {
Assert.assertTrue(err.getTargetException().getMessage().contains("permission denied"));
}
}
}
}
@Test
public void testReadOnlySecurityContext() throws IllegalAccessException {
SecurityContext sc = ReadOnlySecurityContext.INSTANCE;
for (Method method : SecurityContext.class.getMethods()) {
String name = method.getName();
if (name.startsWith("authorize")) {
Class<?>[] parameters = method.getParameterTypes();
try {
switch (parameters.length) {
case 0:
method.invoke(sc, NO_PARAM_ARGS);
if (name.startsWith("authorizeShow")) {
continue;
}
Assert.fail();
break;
case 1:
if (name.equals("authorizeCopyCancel")) {
method.invoke(sc, sc);
} else if (name.equals("authorizeTableBackup")) {
method.invoke(sc, new ObjHashSet<CharSequence>());
} else if (name.startsWith("authorizeShow")) {
method.invoke(sc, "userName");
} else {
method.invoke(sc, ONE_PARAM_ARGS);
}
if (name.startsWith("authorizeShow")) {
continue;
}
Assert.fail();
break;
case 2:
method.invoke(sc, TWO_PARAM_ARGS);
if (name.equals("authorizeSelect")) {
continue;
}
Assert.fail();
break;
case 3:
method.invoke(sc, THREE_PARAM_ARGS);
Assert.fail();
break;
default:
throw new IndexOutOfBoundsException();
}
} catch (InvocationTargetException err) {
Assert.assertTrue(err.getTargetException().getMessage().contains("permission denied"));
}
}
}
}
}
...@@ -69,19 +69,18 @@ public abstract class BasePGTest extends AbstractCairoTest { ...@@ -69,19 +69,18 @@ public abstract class BasePGTest extends AbstractCairoTest {
assertResultSet(null, expected, sink, rs); assertResultSet(null, expected, sink, rs);
} }
public static void assertResultSet(String message, CharSequence expected, StringSink sink, ResultSet rs) throws SQLException, IOException { public static PGWireServer createPGWireServer(
printToSink(sink, rs, null); PGWireConfiguration configuration,
TestUtils.assertEquals(message, expected, sink); CairoEngine cairoEngine,
} WorkerPool workerPool,
PGWireServer.PGConnectionContextFactory contextFactory,
public static void assertResultSet(String message, CharSequence expected, StringSink sink, ResultSet rs, IntIntHashMap map) throws SQLException, IOException { CircuitBreakerRegistry registry
printToSink(sink, rs, map); ) {
TestUtils.assertEquals(message, expected, sink); if (!configuration.isEnabled()) {
} return null;
}
public static void assertResultSet(CharSequence expected, StringSink sink, ResultSet rs, IntIntHashMap map) throws SQLException, IOException { return new PGWireServer(configuration, cairoEngine, workerPool, contextFactory, registry);
printToSink(sink, rs, map);
TestUtils.assertEquals(null, expected, sink);
} }
public static PGWireServer createPGWireServer( public static PGWireServer createPGWireServer(
...@@ -110,20 +109,6 @@ public abstract class BasePGTest extends AbstractCairoTest { ...@@ -110,20 +109,6 @@ public abstract class BasePGTest extends AbstractCairoTest {
); );
} }
public static PGWireServer createPGWireServer(
PGWireConfiguration configuration,
CairoEngine cairoEngine,
WorkerPool workerPool,
PGWireServer.PGConnectionContextFactory contextFactory,
CircuitBreakerRegistry registry
) {
if (!configuration.isEnabled()) {
return null;
}
return new PGWireServer(configuration, cairoEngine, workerPool, contextFactory, registry);
}
private static void toSink(InputStream is, CharSink sink) throws IOException { private static void toSink(InputStream is, CharSink sink) throws IOException {
// limit what we print // limit what we print
byte[] bb = new byte[1]; byte[] bb = new byte[1];
...@@ -159,7 +144,21 @@ public abstract class BasePGTest extends AbstractCairoTest { ...@@ -159,7 +144,21 @@ public abstract class BasePGTest extends AbstractCairoTest {
} }
} }
protected static long printToSink(StringSink sink, ResultSet rs, @Nullable IntIntHashMap map) throws SQLException, IOException { protected static void assertResultSet(CharSequence expected, StringSink sink, ResultSet rs, @Nullable IntIntHashMap map) throws SQLException, IOException {
assertResultSet(null, expected, sink, rs, map);
}
protected static void assertResultSet(String message, CharSequence expected, StringSink sink, ResultSet rs, @Nullable IntIntHashMap map) throws SQLException, IOException {
printToSink(sink, rs, map);
TestUtils.assertEquals(message, expected, sink);
}
protected static void assertResultSet(String message, CharSequence expected, StringSink sink, ResultSet rs) throws SQLException, IOException {
printToSink(sink, rs, null);
TestUtils.assertEquals(message, expected, sink);
}
public static long printToSink(StringSink sink, ResultSet rs, @Nullable IntIntHashMap map) throws SQLException, IOException {
// dump metadata // dump metadata
ResultSetMetaData metaData = rs.getMetaData(); ResultSetMetaData metaData = rs.getMetaData();
final int columnCount = metaData.getColumnCount(); final int columnCount = metaData.getColumnCount();
...@@ -185,6 +184,7 @@ public abstract class BasePGTest extends AbstractCairoTest { ...@@ -185,6 +184,7 @@ public abstract class BasePGTest extends AbstractCairoTest {
} }
sink.put('\n'); sink.put('\n');
Timestamp timestamp;
long rows = 0; long rows = 0;
while (rs.next()) { while (rs.next()) {
rows++; rows++;
...@@ -219,7 +219,7 @@ public abstract class BasePGTest extends AbstractCairoTest { ...@@ -219,7 +219,7 @@ public abstract class BasePGTest extends AbstractCairoTest {
} }
break; break;
case TIMESTAMP: case TIMESTAMP:
Timestamp timestamp = rs.getTimestamp(i); timestamp = rs.getTimestamp(i);
if (timestamp == null) { if (timestamp == null) {
sink.put("null"); sink.put("null");
} else { } else {
......
...@@ -39,9 +39,8 @@ import static io.questdb.griffin.SqlKeywords.*; ...@@ -39,9 +39,8 @@ import static io.questdb.griffin.SqlKeywords.*;
public class SqlKeywordsTest { public class SqlKeywordsTest {
@Test protected static final Map<String, String> specialCases = new HashMap<>();
public void testIs() throws Exception { static {
Map<String, String> specialCases = new HashMap<>();
specialCases.put("isColonColon", "::"); specialCases.put("isColonColon", "::");
specialCases.put("isConcatOperator", "||"); specialCases.put("isConcatOperator", "||");
specialCases.put("isMaxIdentifierLength", "max_identifier_length"); specialCases.put("isMaxIdentifierLength", "max_identifier_length");
...@@ -51,20 +50,28 @@ public class SqlKeywordsTest { ...@@ -51,20 +50,28 @@ public class SqlKeywordsTest {
specialCases.put("isStandardConformingStrings", "standard_conforming_strings"); specialCases.put("isStandardConformingStrings", "standard_conforming_strings");
specialCases.put("isTextArray", "text[]"); specialCases.put("isTextArray", "text[]");
specialCases.put("isTransactionIsolation", "transaction_isolation"); specialCases.put("isTransactionIsolation", "transaction_isolation");
}
@Test
public void testIs() throws Exception {
testIs(SqlKeywords.class.getMethods());
}
Method[] methods = SqlKeywords.class.getMethods(); protected static void testIs(Method[] methods) throws Exception {
Arrays.sort(methods, Comparator.comparing(Method::getName)); Arrays.sort(methods, Comparator.comparing(Method::getName));
for (Method method : methods) { for (Method method : methods) {
String name; String name;
int m = method.getModifiers() & Modifier.methodModifiers(); int m = method.getModifiers() & Modifier.methodModifiers();
if (Modifier.isPublic(m) && Modifier.isStatic(m) && (name = method.getName()).startsWith("is")) { if (Modifier.isPublic(m) && Modifier.isStatic(m) && (name = method.getName()).startsWith("is")) {
String keyword; String methodParam = specialCases.get(name);
if (name.endsWith("Keyword")) { if (methodParam == null) {
keyword = name.substring(2, name.length() - 7).toLowerCase(); if (!name.endsWith("Keyword")) {
} else { Assert.fail("if method name does not end with 'Keyword', it has to be a special case");
keyword = specialCases.get(name); }
methodParam = name.substring(2, name.length() - 7).toLowerCase();
} }
Assert.assertTrue((boolean) method.invoke(null, keyword)); Assert.assertTrue((boolean) method.invoke(null, methodParam));
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册