未验证 提交 fe56f1d6 编写于 作者: L Liang Zhang 提交者: GitHub

Decouple AuthenticationEngine and BackendConnection (#7038)

* Decouple AuthenticationEngine and BackendConnection

* Decouple AuthenticationEngine and BackendConnection
上级 5ee3490c
...@@ -69,7 +69,7 @@ public final class BackendConnection implements JDBCExecutionConnection, AutoClo ...@@ -69,7 +69,7 @@ public final class BackendConnection implements JDBCExecutionConnection, AutoClo
private int connectionId; private int connectionId;
@Setter @Setter
private String userName; private String username;
private final Multimap<String, Connection> cachedConnections = LinkedHashMultimap.create(); private final Multimap<String, Connection> cachedConnections = LinkedHashMultimap.create();
......
...@@ -53,7 +53,7 @@ public final class ShowDatabasesBackendHandler implements TextProtocolBackendHan ...@@ -53,7 +53,7 @@ public final class ShowDatabasesBackendHandler implements TextProtocolBackendHan
private Collection getSchemaNames() { private Collection getSchemaNames() {
Collection<String> result = new LinkedList<>(ProxySchemaContexts.getInstance().getSchemaNames()); Collection<String> result = new LinkedList<>(ProxySchemaContexts.getInstance().getSchemaNames());
Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUserName()).getAuthorizedSchemas(); Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUsername()).getAuthorizedSchemas();
if (!authorizedSchemas.isEmpty()) { if (!authorizedSchemas.isEmpty()) {
result.retainAll(authorizedSchemas); result.retainAll(authorizedSchemas);
} }
......
...@@ -53,7 +53,7 @@ public final class UseDatabaseBackendHandler implements TextProtocolBackendHandl ...@@ -53,7 +53,7 @@ public final class UseDatabaseBackendHandler implements TextProtocolBackendHandl
} }
private boolean isAuthorizedSchema(final String schema) { private boolean isAuthorizedSchema(final String schema) {
Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUserName()).getAuthorizedSchemas(); Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUsername()).getAuthorizedSchemas();
return authorizedSchemas.isEmpty() || authorizedSchemas.contains(schema); return authorizedSchemas.isEmpty() || authorizedSchemas.contains(schema);
} }
......
...@@ -56,7 +56,7 @@ public final class ShowDatabasesBackendHandlerTest { ...@@ -56,7 +56,7 @@ public final class ShowDatabasesBackendHandlerTest {
@SneakyThrows(ReflectiveOperationException.class) @SneakyThrows(ReflectiveOperationException.class)
public void setUp() { public void setUp() {
BackendConnection backendConnection = mock(BackendConnection.class); BackendConnection backendConnection = mock(BackendConnection.class);
when(backendConnection.getUserName()).thenReturn("root"); when(backendConnection.getUsername()).thenReturn("root");
showDatabasesBackendHandler = new ShowDatabasesBackendHandler(backendConnection); showDatabasesBackendHandler = new ShowDatabasesBackendHandler(backendConnection);
Field schemaContexts = ProxySchemaContexts.getInstance().getClass().getDeclaredField("schemaContexts"); Field schemaContexts = ProxySchemaContexts.getInstance().getClass().getDeclaredField("schemaContexts");
schemaContexts.setAccessible(true); schemaContexts.setAccessible(true);
......
...@@ -54,7 +54,7 @@ public class ShowTablesBackendHandlerTest { ...@@ -54,7 +54,7 @@ public class ShowTablesBackendHandlerTest {
@SneakyThrows(ReflectiveOperationException.class) @SneakyThrows(ReflectiveOperationException.class)
public void setUp() { public void setUp() {
BackendConnection backendConnection = mock(BackendConnection.class); BackendConnection backendConnection = mock(BackendConnection.class);
when(backendConnection.getUserName()).thenReturn("root"); when(backendConnection.getUsername()).thenReturn("root");
tablesBackendHandler = new ShowTablesBackendHandler("show tables", mock(SQLStatement.class), backendConnection); tablesBackendHandler = new ShowTablesBackendHandler("show tables", mock(SQLStatement.class), backendConnection);
Map<String, SchemaContext> schemaContextMap = getSchemaContextMap(); Map<String, SchemaContext> schemaContextMap = getSchemaContextMap();
when(backendConnection.getSchema()).thenReturn("schema_0"); when(backendConnection.getSchema()).thenReturn("schema_0");
......
...@@ -58,7 +58,7 @@ public final class UseDatabaseBackendHandlerTest { ...@@ -58,7 +58,7 @@ public final class UseDatabaseBackendHandlerTest {
@SneakyThrows(ReflectiveOperationException.class) @SneakyThrows(ReflectiveOperationException.class)
public void setUp() { public void setUp() {
backendConnection = mock(BackendConnection.class); backendConnection = mock(BackendConnection.class);
when(backendConnection.getUserName()).thenReturn("root"); when(backendConnection.getUsername()).thenReturn("root");
Field schemaContexts = ProxySchemaContexts.getInstance().getClass().getDeclaredField("schemaContexts"); Field schemaContexts = ProxySchemaContexts.getInstance().getClass().getDeclaredField("schemaContexts");
schemaContexts.setAccessible(true); schemaContexts.setAccessible(true);
schemaContexts.set(ProxySchemaContexts.getInstance(), schemaContexts.set(ProxySchemaContexts.getInstance(),
......
...@@ -78,9 +78,9 @@ public final class FrontendChannelInboundHandler extends ChannelInboundHandlerAd ...@@ -78,9 +78,9 @@ public final class FrontendChannelInboundHandler extends ChannelInboundHandlerAd
private boolean auth(final ChannelHandlerContext context, final ByteBuf message) { private boolean auth(final ChannelHandlerContext context, final ByteBuf message) {
try (PacketPayload payload = databaseProtocolFrontendEngine.getCodecEngine().createPacketPayload(message)) { try (PacketPayload payload = databaseProtocolFrontendEngine.getCodecEngine().createPacketPayload(message)) {
AuthenticationResult authResult = databaseProtocolFrontendEngine.getAuthEngine().auth(context, payload, backendConnection); AuthenticationResult authResult = databaseProtocolFrontendEngine.getAuthEngine().auth(context, payload);
if (authResult.isFinished()) { if (authResult.isFinished()) {
backendConnection.setUserName(authResult.getUsername()); backendConnection.setUsername(authResult.getUsername());
backendConnection.setCurrentSchema(authResult.getDatabase()); backendConnection.setCurrentSchema(authResult.getDatabase());
} }
return authResult.isFinished(); return authResult.isFinished();
......
...@@ -31,7 +31,6 @@ import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandsha ...@@ -31,7 +31,6 @@ import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandsha
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet; import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload; import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.db.protocol.payload.PacketPayload; import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts; import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts;
import org.apache.shardingsphere.proxy.frontend.ConnectionIdGenerator; import org.apache.shardingsphere.proxy.frontend.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationEngine; import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationEngine;
...@@ -65,7 +64,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine { ...@@ -65,7 +64,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
} }
@Override @Override
public AuthenticationResult auth(final ChannelHandlerContext context, final PacketPayload payload, final BackendConnection backendConnection) { public AuthenticationResult auth(final ChannelHandlerContext context, final PacketPayload payload) {
if (MySQLConnectionPhase.AUTH_PHASE_FAST_PATH == connectionPhase) { if (MySQLConnectionPhase.AUTH_PHASE_FAST_PATH == connectionPhase) {
currentAuthResult = authPhaseFastPath(context, payload); currentAuthResult = authPhaseFastPath(context, payload);
if (!currentAuthResult.isFinished()) { if (!currentAuthResult.isFinished()) {
...@@ -75,13 +74,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine { ...@@ -75,13 +74,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
authenticationMethodMismatch((MySQLPacketPayload) payload); authenticationMethodMismatch((MySQLPacketPayload) payload);
} }
Optional<MySQLServerErrorCode> errorCode = authenticationHandler.login(currentAuthResult.getUsername(), authResponse, currentAuthResult.getDatabase()); Optional<MySQLServerErrorCode> errorCode = authenticationHandler.login(currentAuthResult.getUsername(), authResponse, currentAuthResult.getDatabase());
if (errorCode.isPresent()) { context.writeAndFlush(errorCode.isPresent() ? createErrorPacket(errorCode.get(), context) : new MySQLOKPacket(++sequenceId));
context.writeAndFlush(createMySQLErrPacket(errorCode.get(), context));
} else {
backendConnection.setCurrentSchema(currentAuthResult.getDatabase());
backendConnection.setUserName(currentAuthResult.getUsername());
context.writeAndFlush(new MySQLOKPacket(++sequenceId));
}
return AuthenticationResult.finished(currentAuthResult.getUsername(), currentAuthResult.getDatabase()); return AuthenticationResult.finished(currentAuthResult.getUsername(), currentAuthResult.getDatabase());
} }
...@@ -96,7 +89,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine { ...@@ -96,7 +89,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
if (isClientPluginAuth(packet) && !MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION.getMethodName().equals(packet.getAuthPluginName())) { if (isClientPluginAuth(packet) && !MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION.getMethodName().equals(packet.getAuthPluginName())) {
connectionPhase = MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH; connectionPhase = MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH;
context.writeAndFlush(new MySQLAuthSwitchRequestPacket(++sequenceId, MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION.getMethodName(), authenticationHandler.getAuthPluginData())); context.writeAndFlush(new MySQLAuthSwitchRequestPacket(++sequenceId, MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION.getMethodName(), authenticationHandler.getAuthPluginData()));
return AuthenticationResult.continued(); return AuthenticationResult.continued(packet.getUsername(), packet.getDatabase());
} }
return AuthenticationResult.finished(packet.getUsername(), packet.getDatabase()); return AuthenticationResult.finished(packet.getUsername(), packet.getDatabase());
} }
...@@ -111,7 +104,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine { ...@@ -111,7 +104,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
authResponse = packet.getAuthPluginResponse(); authResponse = packet.getAuthPluginResponse();
} }
private MySQLErrPacket createMySQLErrPacket(final MySQLServerErrorCode errorCode, final ChannelHandlerContext context) { private MySQLErrPacket createErrorPacket(final MySQLServerErrorCode errorCode, final ChannelHandlerContext context) {
return MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR == errorCode return MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR == errorCode
? new MySQLErrPacket(++sequenceId, MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR, currentAuthResult.getUsername(), getHostAddress(context), currentAuthResult.getDatabase()) ? new MySQLErrPacket(++sequenceId, MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR, currentAuthResult.getUsername(), getHostAddress(context), currentAuthResult.getDatabase())
: new MySQLErrPacket(++sequenceId, MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR, currentAuthResult.getUsername(), getHostAddress(context), getErrorMessage()); : new MySQLErrPacket(++sequenceId, MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR, currentAuthResult.getUsername(), getHostAddress(context), getErrorMessage());
......
...@@ -52,7 +52,7 @@ public final class MySQLComInitDbExecutor implements CommandExecutor { ...@@ -52,7 +52,7 @@ public final class MySQLComInitDbExecutor implements CommandExecutor {
} }
private boolean isAuthorizedSchema(final String schema) { private boolean isAuthorizedSchema(final String schema) {
Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUserName()).getAuthorizedSchemas(); Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUsername()).getAuthorizedSchemas();
return authorizedSchemas.isEmpty() || authorizedSchemas.contains(schema); return authorizedSchemas.isEmpty() || authorizedSchemas.contains(schema);
} }
} }
...@@ -34,7 +34,6 @@ import org.apache.shardingsphere.kernel.context.SchemaContexts; ...@@ -34,7 +34,6 @@ import org.apache.shardingsphere.kernel.context.SchemaContexts;
import org.apache.shardingsphere.kernel.context.StandardSchemaContexts; import org.apache.shardingsphere.kernel.context.StandardSchemaContexts;
import org.apache.shardingsphere.kernel.context.runtime.RuntimeContext; import org.apache.shardingsphere.kernel.context.runtime.RuntimeContext;
import org.apache.shardingsphere.kernel.context.schema.ShardingSphereSchema; import org.apache.shardingsphere.kernel.context.schema.ShardingSphereSchema;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts; import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts;
import org.apache.shardingsphere.proxy.frontend.ConnectionIdGenerator; import org.apache.shardingsphere.proxy.frontend.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationResult; import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationResult;
...@@ -105,7 +104,7 @@ public final class MySQLProtocolFrontendEngineTest { ...@@ -105,7 +104,7 @@ public final class MySQLProtocolFrontendEngineTest {
ProxyUser proxyUser = new ProxyUser("", Collections.singleton("db1")); ProxyUser proxyUser = new ProxyUser("", Collections.singleton("db1"));
setAuthentication(proxyUser); setAuthentication(proxyUser);
when(payload.readStringNul()).thenReturn("root"); when(payload.readStringNul()).thenReturn("root");
AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload, mock(BackendConnection.class)); AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload);
assertThat(actual.getUsername(), is("root")); assertThat(actual.getUsername(), is("root"));
assertNull(actual.getDatabase()); assertNull(actual.getDatabase());
assertTrue(actual.isFinished()); assertTrue(actual.isFinished());
...@@ -121,7 +120,7 @@ public final class MySQLProtocolFrontendEngineTest { ...@@ -121,7 +120,7 @@ public final class MySQLProtocolFrontendEngineTest {
when(payload.readStringNulByBytes()).thenReturn("root".getBytes()); when(payload.readStringNulByBytes()).thenReturn("root".getBytes());
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 3307)); when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 3307));
when(context.channel()).thenReturn(channel); when(context.channel()).thenReturn(channel);
AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload, mock(BackendConnection.class)); AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload);
assertThat(actual.getUsername(), is("root")); assertThat(actual.getUsername(), is("root"));
assertNull(actual.getDatabase()); assertNull(actual.getDatabase());
assertTrue(actual.isFinished()); assertTrue(actual.isFinished());
...@@ -138,12 +137,11 @@ public final class MySQLProtocolFrontendEngineTest { ...@@ -138,12 +137,11 @@ public final class MySQLProtocolFrontendEngineTest {
when(payload.readStringNulByBytes()).thenReturn("root".getBytes()); when(payload.readStringNulByBytes()).thenReturn("root".getBytes());
when(context.channel()).thenReturn(channel); when(context.channel()).thenReturn(channel);
when(channel.remoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getByAddress(new byte[] {(byte) 192, (byte) 168, (byte) 0, (byte) 102}), 3307)); when(channel.remoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getByAddress(new byte[] {(byte) 192, (byte) 168, (byte) 0, (byte) 102}), 3307));
AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload, mock(BackendConnection.class)); AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload);
assertThat(actual.getUsername(), is("root")); assertThat(actual.getUsername(), is("root"));
assertNull(actual.getDatabase()); assertNull(actual.getDatabase());
assertTrue(actual.isFinished()); assertTrue(actual.isFinished());
verify(context).writeAndFlush(argThat( verify(context).writeAndFlush(argThat((ArgumentMatcher<MySQLErrPacket>) argument -> "Access denied for user 'root'@'192.168.0.102' (using password: YES)".equals(argument.getErrorMessage())));
(ArgumentMatcher<MySQLErrPacket>) argument -> "Access denied for user 'root'@'192.168.0.102' (using password: YES)".equals(argument.getErrorMessage())));
} }
@SneakyThrows @SneakyThrows
...@@ -168,8 +166,7 @@ public final class MySQLProtocolFrontendEngineTest { ...@@ -168,8 +166,7 @@ public final class MySQLProtocolFrontendEngineTest {
} }
private SchemaContexts getSchemaContexts(final Authentication authentication) { private SchemaContexts getSchemaContexts(final Authentication authentication) {
return new StandardSchemaContexts(getSchemaContextMap(), authentication, return new StandardSchemaContexts(getSchemaContextMap(), authentication, new ConfigurationProperties(new Properties()), new MySQLDatabaseType());
new ConfigurationProperties(new Properties()), new MySQLDatabaseType());
} }
private Map<String, SchemaContext> getSchemaContextMap() { private Map<String, SchemaContext> getSchemaContextMap() {
......
...@@ -32,8 +32,8 @@ import org.apache.shardingsphere.infra.config.properties.ConfigurationProperties ...@@ -32,8 +32,8 @@ import org.apache.shardingsphere.infra.config.properties.ConfigurationProperties
import org.apache.shardingsphere.infra.database.type.dialect.MySQLDatabaseType; import org.apache.shardingsphere.infra.database.type.dialect.MySQLDatabaseType;
import org.apache.shardingsphere.kernel.context.SchemaContext; import org.apache.shardingsphere.kernel.context.SchemaContext;
import org.apache.shardingsphere.kernel.context.StandardSchemaContexts; import org.apache.shardingsphere.kernel.context.StandardSchemaContexts;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts; import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts;
import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationResult;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
...@@ -85,7 +85,7 @@ public final class MySQLAuthenticationEngineTest { ...@@ -85,7 +85,7 @@ public final class MySQLAuthenticationEngineTest {
MySQLPacketPayload payload = mock(MySQLPacketPayload.class); MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class); ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class);
when(payload.readInt4()).thenReturn(MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue()); when(payload.readInt4()).thenReturn(MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue());
authenticationEngine.auth(channelHandlerContext, payload, mock(BackendConnection.class)); authenticationEngine.auth(channelHandlerContext, payload);
assertThat(getConnectionPhase(), is(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH)); assertThat(getConnectionPhase(), is(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH));
} }
...@@ -95,17 +95,25 @@ public final class MySQLAuthenticationEngineTest { ...@@ -95,17 +95,25 @@ public final class MySQLAuthenticationEngineTest {
MySQLPacketPayload payload = mock(MySQLPacketPayload.class); MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class); ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class);
when(payload.readStringEOFByBytes()).thenReturn(authResponse); when(payload.readStringEOFByBytes()).thenReturn(authResponse);
authenticationEngine.auth(channelHandlerContext, payload, mock(BackendConnection.class)); setAuthenticationResult();
authenticationEngine.auth(channelHandlerContext, payload);
assertThat(getAuthResponse(), is(authResponse)); assertThat(getAuthResponse(), is(authResponse));
} }
@SneakyThrows(ReflectiveOperationException.class)
private void setAuthenticationResult() {
Field field = MySQLAuthenticationEngine.class.getDeclaredField("currentAuthResult");
field.setAccessible(true);
field.set(authenticationEngine, AuthenticationResult.continued("root", "sharding_db"));
}
@Test @Test
public void assertAuthWithLoginFail() throws NoSuchFieldException, IllegalAccessException { public void assertAuthWithLoginFail() throws NoSuchFieldException, IllegalAccessException {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH); setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
ChannelHandlerContext context = getContext(); ChannelHandlerContext context = getContext();
setSchemas(); setSchemas();
when(authenticationHandler.login(anyString(), any(), anyString())).thenReturn(Optional.of(MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR)); when(authenticationHandler.login(anyString(), any(), anyString())).thenReturn(Optional.of(MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR));
authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse), mock(BackendConnection.class)); authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse));
verify(context).writeAndFlush(any(MySQLErrPacket.class)); verify(context).writeAndFlush(any(MySQLErrPacket.class));
} }
...@@ -114,7 +122,7 @@ public final class MySQLAuthenticationEngineTest { ...@@ -114,7 +122,7 @@ public final class MySQLAuthenticationEngineTest {
ChannelHandlerContext context = getContext(); ChannelHandlerContext context = getContext();
setSchemas(); setSchemas();
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH); setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
authenticationEngine.auth(context, getPayload("root", "ABSENT DATABASE", authResponse), mock(BackendConnection.class)); authenticationEngine.auth(context, getPayload("root", "ABSENT DATABASE", authResponse));
verify(context).writeAndFlush(any(MySQLErrPacket.class)); verify(context).writeAndFlush(any(MySQLErrPacket.class));
} }
...@@ -124,7 +132,7 @@ public final class MySQLAuthenticationEngineTest { ...@@ -124,7 +132,7 @@ public final class MySQLAuthenticationEngineTest {
ChannelHandlerContext context = getContext(); ChannelHandlerContext context = getContext();
when(authenticationHandler.login(anyString(), any(), anyString())).thenReturn(Optional.empty()); when(authenticationHandler.login(anyString(), any(), anyString())).thenReturn(Optional.empty());
setSchemas(); setSchemas();
authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse), mock(BackendConnection.class)); authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse));
verify(context).writeAndFlush(any(MySQLOKPacket.class)); verify(context).writeAndFlush(any(MySQLOKPacket.class));
} }
...@@ -136,10 +144,10 @@ public final class MySQLAuthenticationEngineTest { ...@@ -136,10 +144,10 @@ public final class MySQLAuthenticationEngineTest {
new Authentication(), new ConfigurationProperties(new Properties()), new MySQLDatabaseType())); new Authentication(), new ConfigurationProperties(new Properties()), new MySQLDatabaseType()));
} }
private MySQLPacketPayload getPayload(final String userName, final String database, final byte[] authResponse) { private MySQLPacketPayload getPayload(final String username, final String database, final byte[] authResponse) {
MySQLPacketPayload result = mock(MySQLPacketPayload.class); MySQLPacketPayload result = mock(MySQLPacketPayload.class);
when(result.readInt4()).thenReturn(MySQLCapabilityFlag.CLIENT_CONNECT_WITH_DB.getValue()); when(result.readInt4()).thenReturn(MySQLCapabilityFlag.CLIENT_CONNECT_WITH_DB.getValue());
when(result.readStringNul()).thenReturn(userName).thenReturn(database); when(result.readStringNul()).thenReturn(username).thenReturn(database);
when(result.readStringNulByBytes()).thenReturn(authResponse); when(result.readStringNulByBytes()).thenReturn(authResponse);
return result; return result;
} }
......
...@@ -32,7 +32,6 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.Postgre ...@@ -32,7 +32,6 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.Postgre
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLRandomGenerator; import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLRandomGenerator;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLNegativePacket; import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLNegativePacket;
import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload; import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts; import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts;
import org.apache.shardingsphere.proxy.frontend.ConnectionIdGenerator; import org.apache.shardingsphere.proxy.frontend.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationEngine; import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationEngine;
...@@ -67,7 +66,7 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin ...@@ -67,7 +66,7 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
} }
@Override @Override
public AuthenticationResult auth(final ChannelHandlerContext context, final PacketPayload payload, final BackendConnection backendConnection) { public AuthenticationResult auth(final ChannelHandlerContext context, final PacketPayload payload) {
if (SSL_REQUEST_PAYLOAD_LENGTH == payload.getByteBuf().markReaderIndex().readInt() && SSL_REQUEST_CODE == payload.getByteBuf().readInt()) { if (SSL_REQUEST_PAYLOAD_LENGTH == payload.getByteBuf().markReaderIndex().readInt() && SSL_REQUEST_CODE == payload.getByteBuf().readInt()) {
context.writeAndFlush(new PostgreSQLSSLNegativePacket()); context.writeAndFlush(new PostgreSQLSSLNegativePacket());
return AuthenticationResult.continued(); return AuthenticationResult.continued();
...@@ -78,41 +77,35 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin ...@@ -78,41 +77,35 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
startupMessageReceived.set(true); startupMessageReceived.set(true);
String databaseName = comStartupPacket.getParametersMap().get(DATABASE_NAME_KEYWORD); String databaseName = comStartupPacket.getParametersMap().get(DATABASE_NAME_KEYWORD);
if (!Strings.isNullOrEmpty(databaseName) && !ProxySchemaContexts.getInstance().schemaExists(databaseName)) { if (!Strings.isNullOrEmpty(databaseName) && !ProxySchemaContexts.getInstance().schemaExists(databaseName)) {
PostgreSQLErrorResponsePacket responsePacket = createPostgreSQLErrorResponsePacket(PostgreSQLErrorCode.INVALID_CATALOG_NAME, PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(PostgreSQLErrorCode.INVALID_CATALOG_NAME, String.format("database \"%s\" does not exist", databaseName));
String.format("database \"%s\" does not exist", databaseName));
context.writeAndFlush(responsePacket); context.writeAndFlush(responsePacket);
context.close(); context.close();
return AuthenticationResult.continued(); return AuthenticationResult.continued();
} }
backendConnection.setCurrentSchema(databaseName); String username = comStartupPacket.getParametersMap().get(USER_NAME_KEYWORD);
String userName = comStartupPacket.getParametersMap().get(USER_NAME_KEYWORD); if (null == username || username.isEmpty()) {
if (null == userName || userName.isEmpty()) { PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION, "user not set in StartupMessage");
PostgreSQLErrorResponsePacket responsePacket = createPostgreSQLErrorResponsePacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"user not set in StartupMessage");
context.writeAndFlush(responsePacket); context.writeAndFlush(responsePacket);
context.close(); context.close();
return AuthenticationResult.continued(); return AuthenticationResult.continued();
} }
backendConnection.setUserName(userName);
md5Salt = PostgreSQLRandomGenerator.getInstance().generateRandomBytes(4); md5Salt = PostgreSQLRandomGenerator.getInstance().generateRandomBytes(4);
context.writeAndFlush(new PostgreSQLAuthenticationMD5PasswordPacket(md5Salt)); context.writeAndFlush(new PostgreSQLAuthenticationMD5PasswordPacket(md5Salt));
return AuthenticationResult.continued(userName, databaseName); return AuthenticationResult.continued(username, databaseName);
} else { } else {
char messageType = (char) ((PostgreSQLPacketPayload) payload).readInt1(); char messageType = (char) ((PostgreSQLPacketPayload) payload).readInt1();
if ('p' != messageType) { if ('p' != messageType) {
PostgreSQLErrorResponsePacket responsePacket = createPostgreSQLErrorResponsePacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION, PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(
"PasswordMessage is expected, message type 'p', but not '" + messageType + "'"); PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION, String.format("PasswordMessage is expected, message type 'p', but not '%s'", messageType));
context.writeAndFlush(responsePacket); context.writeAndFlush(responsePacket);
context.close(); context.close();
currentAuthResult = AuthenticationResult.continued(); currentAuthResult = AuthenticationResult.continued();
return currentAuthResult; return currentAuthResult;
} }
PostgreSQLPasswordMessagePacket passwordMessagePacket = new PostgreSQLPasswordMessagePacket((PostgreSQLPacketPayload) payload); PostgreSQLPasswordMessagePacket passwordMessagePacket = new PostgreSQLPasswordMessagePacket((PostgreSQLPacketPayload) payload);
PostgreSQLLoginResult loginResult = PostgreSQLAuthenticationHandler.loginWithMd5Password( PostgreSQLLoginResult loginResult = PostgreSQLAuthenticationHandler.loginWithMd5Password(currentAuthResult.getUsername(), currentAuthResult.getDatabase(), md5Salt, passwordMessagePacket);
backendConnection.getUserName(), backendConnection.getSchema(), md5Salt, passwordMessagePacket);
if (PostgreSQLErrorCode.SUCCESSFUL_COMPLETION != loginResult.getErrorCode()) { if (PostgreSQLErrorCode.SUCCESSFUL_COMPLETION != loginResult.getErrorCode()) {
PostgreSQLErrorResponsePacket responsePacket = createPostgreSQLErrorResponsePacket(loginResult.getErrorCode(), PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(loginResult.getErrorCode(), loginResult.getErrorMessage());
loginResult.getErrorMessage());
context.writeAndFlush(responsePacket); context.writeAndFlush(responsePacket);
context.close(); context.close();
return AuthenticationResult.continued(); return AuthenticationResult.continued();
...@@ -128,7 +121,7 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin ...@@ -128,7 +121,7 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
} }
} }
private PostgreSQLErrorResponsePacket createPostgreSQLErrorResponsePacket(final PostgreSQLErrorCode errorCode, final String errorMessage) { private PostgreSQLErrorResponsePacket createErrorPacket(final PostgreSQLErrorCode errorCode, final String errorMessage) {
PostgreSQLErrorResponsePacket result = new PostgreSQLErrorResponsePacket(); PostgreSQLErrorResponsePacket result = new PostgreSQLErrorResponsePacket();
result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY, "FATAL"); result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY, "FATAL");
result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, errorCode.getErrorCode()); result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, errorCode.getErrorCode());
......
...@@ -19,7 +19,6 @@ package org.apache.shardingsphere.proxy.frontend.engine; ...@@ -19,7 +19,6 @@ package org.apache.shardingsphere.proxy.frontend.engine;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import org.apache.shardingsphere.db.protocol.payload.PacketPayload; import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
/** /**
* Authentication engine. * Authentication engine.
...@@ -39,8 +38,7 @@ public interface AuthenticationEngine { ...@@ -39,8 +38,7 @@ public interface AuthenticationEngine {
* *
* @param context channel handler context * @param context channel handler context
* @param payload packet payload * @param payload packet payload
* @param backendConnection backend connection
* @return authentication result * @return authentication result
*/ */
AuthenticationResult auth(ChannelHandlerContext context, PacketPayload payload, BackendConnection backendConnection); AuthenticationResult auth(ChannelHandlerContext context, PacketPayload payload);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册