未验证 提交 cfbec47a 编写于 作者: H Haoran Meng 提交者: GitHub

Merge pull request #7016 from terrymanu/dev

Decouple AuthenticationEngine.handshake() and BackendConnection
......@@ -37,7 +37,7 @@ public final class ConnectionIdGenerator {
*/
public static ConnectionIdGenerator getInstance() {
return INSTANCE;
}
}
/**
* Get next connection ID.
......
......@@ -58,7 +58,7 @@ public final class FrontendChannelInboundHandler extends ChannelInboundHandlerAd
@Override
public void channelActive(final ChannelHandlerContext context) {
ChannelThreadExecutorGroup.getInstance().register(context.channel().id());
databaseProtocolFrontendEngine.getAuthEngine().handshake(context, backendConnection);
backendConnection.setConnectionId(databaseProtocolFrontendEngine.getAuthEngine().handshake(context));
// TODO ref #7013
SingletonFacadeEngine.buildMetrics().ifPresent(metricsHandlerFacade -> metricsHandlerFacade.gaugeIncrement(MetricsLabelEnum.CHANNEL_COUNT.getName()));
}
......
......@@ -58,11 +58,11 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
private String database;
@Override
public void handshake(final ChannelHandlerContext context, final BackendConnection backendConnection) {
int connectionId = ConnectionIdGenerator.getInstance().nextId();
backendConnection.setConnectionId(connectionId);
public int handshake(final ChannelHandlerContext context) {
int result = ConnectionIdGenerator.getInstance().nextId();
connectionPhase = MySQLConnectionPhase.AUTH_PHASE_FAST_PATH;
context.writeAndFlush(new MySQLHandshakePacket(connectionId, authenticationHandler.getAuthPluginData()));
context.writeAndFlush(new MySQLHandshakePacket(result, authenticationHandler.getAuthPluginData()));
return result;
}
@Override
......
......@@ -91,7 +91,7 @@ public final class MySQLProtocolFrontendEngineTest {
@Test
public void assertHandshake() {
mysqlProtocolFrontendEngine.getAuthEngine().handshake(context, mock(BackendConnection.class));
assertTrue(mysqlProtocolFrontendEngine.getAuthEngine().handshake(context) > 0);
verify(context).writeAndFlush(isA(MySQLHandshakePacket.class));
}
......
......@@ -45,8 +45,8 @@ import java.util.Properties;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
......@@ -54,31 +54,29 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public final class MySQLAuthenticationEngineTest {
private final MySQLAuthenticationHandler authenticationHandler = mock(MySQLAuthenticationHandler.class);
private final MySQLAuthenticationEngine authenticationEngine = new MySQLAuthenticationEngine();
private final byte[] authResponse = {-27, 89, -20, -27, 65, -120, -64, -101, 86, -100, -108, -100, 6, -125, -37, 117, 14, -43, 95, -113};
@Before
public void setUp() throws NoSuchFieldException, IllegalAccessException {
initAuthenticationHandlerForAuthenticationEngine();
}
private void initAuthenticationHandlerForAuthenticationEngine() throws NoSuchFieldException, IllegalAccessException {
Field field = MySQLAuthenticationEngine.class.getDeclaredField("authenticationHandler");
field.setAccessible(true);
field.set(authenticationEngine, authenticationHandler);
}
@Test
public void assertHandshake() {
ChannelHandlerContext context = getContext();
BackendConnection backendConnection = mock(BackendConnection.class);
authenticationEngine.handshake(context, backendConnection);
assertTrue(authenticationEngine.handshake(context) > 0);
verify(context).writeAndFlush(any(MySQLHandshakePacket.class));
verify(backendConnection).setConnectionId(anyInt());
}
@Test
......@@ -100,7 +98,7 @@ public final class MySQLAuthenticationEngineTest {
authenticationEngine.auth(channelHandlerContext, payload, mock(BackendConnection.class));
assertThat(getAuthResponse(), is(authResponse));
}
@Test
public void assertAuthWithLoginFail() throws NoSuchFieldException, IllegalAccessException {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
......@@ -110,7 +108,7 @@ public final class MySQLAuthenticationEngineTest {
authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse), mock(BackendConnection.class));
verify(context).writeAndFlush(any(MySQLErrPacket.class));
}
@Test
public void assertAuthWithAbsentDatabase() throws NoSuchFieldException, IllegalAccessException {
ChannelHandlerContext context = getContext();
......@@ -119,7 +117,7 @@ public final class MySQLAuthenticationEngineTest {
authenticationEngine.auth(context, getPayload("root", "ABSENT DATABASE", authResponse), mock(BackendConnection.class));
verify(context).writeAndFlush(any(MySQLErrPacket.class));
}
@Test
public void assertAuth() throws NoSuchFieldException, IllegalAccessException {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
......@@ -145,19 +143,19 @@ public final class MySQLAuthenticationEngineTest {
when(result.readStringNulByBytes()).thenReturn(authResponse);
return result;
}
private ChannelHandlerContext getContext() {
ChannelHandlerContext result = mock(ChannelHandlerContext.class);
doReturn(getChannel()).when(result).channel();
return result;
}
private Channel getChannel() {
Channel result = mock(Channel.class);
doReturn(getRemoteAddress()).when(result).remoteAddress();
return result;
}
private SocketAddress getRemoteAddress() {
SocketAddress result = mock(SocketAddress.class);
when(result.toString()).thenReturn("127.0.0.1");
......
......@@ -57,10 +57,10 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
private volatile byte[] md5Salt;
@Override
public void handshake(final ChannelHandlerContext context, final BackendConnection backendConnection) {
int connectionId = ConnectionIdGenerator.getInstance().nextId();
backendConnection.setConnectionId(connectionId);
BinaryStatementRegistry.getInstance().register(connectionId);
public int handshake(final ChannelHandlerContext context) {
int result = ConnectionIdGenerator.getInstance().nextId();
BinaryStatementRegistry.getInstance().register(result);
return result;
}
@Override
......
......@@ -30,9 +30,9 @@ public interface AuthenticationEngine {
* Handshake.
*
* @param context channel handler context
* @param backendConnection backend connection
* @return connection ID
*/
void handshake(ChannelHandlerContext context, BackendConnection backendConnection);
int handshake(ChannelHandlerContext context);
/**
* Authentication.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册