未验证 提交 932358b4 编写于 作者: 杨翊 SionYang 提交者: GitHub

Replace handshake and query command in sharding-scaling-mysql (#4824)

上级 6b569758
......@@ -17,14 +17,16 @@
package org.apache.shardingsphere.shardingscaling.mysql.binlog;
import org.apache.shardingsphere.database.protocol.codec.PacketCodec;
import org.apache.shardingsphere.database.protocol.mysql.codec.MySQLPacketCodecEngine;
import org.apache.shardingsphere.database.protocol.mysql.packet.command.query.text.query.MySQLComQueryPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.codec.MySQLBinlogEventPacketDecoder;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.codec.MySQLCommandPacketDecoder;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.codec.MySQLLengthFieldBasedFrameEncoder;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.event.AbstractBinlogEvent;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.command.BinlogDumpCommandPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.command.QueryCommandPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.command.RegisterSlaveCommandPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.ErrorPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.InternalResultSet;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.OkPacket;
import io.netty.bootstrap.Bootstrap;
......@@ -37,14 +39,12 @@ import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Promise;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.net.InetSocketAddress;
import java.nio.ByteOrder;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutionException;
......@@ -86,8 +86,7 @@ public final class MySQLConnector {
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(final SocketChannel socketChannel) {
socketChannel.pipeline().addLast(new LengthFieldBasedFrameDecoder(ByteOrder.LITTLE_ENDIAN, Integer.MAX_VALUE, 0, 3, 1, 4, true));
socketChannel.pipeline().addLast(MySQLLengthFieldBasedFrameEncoder.class.getSimpleName(), new MySQLLengthFieldBasedFrameEncoder());
socketChannel.pipeline().addLast(new PacketCodec(new MySQLPacketCodecEngine()));
socketChannel.pipeline().addLast(new MySQLCommandPacketDecoder());
socketChannel.pipeline().addLast(new MySQLNegotiateHandler(username, password, responseCallback));
socketChannel.pipeline().addLast(new MySQLCommandResponseHandler());
......@@ -106,10 +105,9 @@ public final class MySQLConnector {
*/
public synchronized boolean execute(final String queryString) {
responseCallback = new DefaultPromise<>(eventLoopGroup.next());
QueryCommandPacket queryCommandPacket = new QueryCommandPacket();
queryCommandPacket.setQueryString(queryString);
channel.writeAndFlush(queryCommandPacket);
return null != waitExpectedResponse(OkPacket.class);
MySQLComQueryPacket comQueryPacket = new MySQLComQueryPacket(queryString);
channel.writeAndFlush(comQueryPacket);
return null != waitExpectedResponse(MySQLOKPacket.class);
}
/**
......@@ -120,10 +118,9 @@ public final class MySQLConnector {
*/
public synchronized int executeUpdate(final String queryString) {
responseCallback = new DefaultPromise<>(eventLoopGroup.next());
QueryCommandPacket queryCommandPacket = new QueryCommandPacket();
queryCommandPacket.setQueryString(queryString);
channel.writeAndFlush(queryCommandPacket);
return (int) waitExpectedResponse(OkPacket.class).getAffectedRows();
MySQLComQueryPacket comQueryPacket = new MySQLComQueryPacket(queryString);
channel.writeAndFlush(comQueryPacket);
return (int) waitExpectedResponse(MySQLOKPacket.class).getAffectedRows();
}
/**
......@@ -134,9 +131,8 @@ public final class MySQLConnector {
*/
public synchronized InternalResultSet executeQuery(final String queryString) {
responseCallback = new DefaultPromise<>(eventLoopGroup.next());
QueryCommandPacket queryCommandPacket = new QueryCommandPacket();
queryCommandPacket.setQueryString(queryString);
channel.writeAndFlush(queryCommandPacket);
MySQLComQueryPacket comQueryPacket = new MySQLComQueryPacket(queryString);
channel.writeAndFlush(comQueryPacket);
return waitExpectedResponse(InternalResultSet.class);
}
......@@ -173,7 +169,7 @@ public final class MySQLConnector {
return 0;
}
InternalResultSet resultSet = executeQuery("SELECT @@GLOBAL.BINLOG_CHECKSUM");
String checksumType = resultSet.getFieldValues().get(0).getColumns().get(0);
String checksumType = resultSet.getFieldValues().get(0).getData().get(0).toString();
switch (checksumType) {
case "None":
return 0;
......@@ -215,8 +211,8 @@ public final class MySQLConnector {
if (type.equals(response.getClass())) {
return (T) response;
}
if (response instanceof ErrorPacket) {
throw new RuntimeException(((ErrorPacket) response).getMessage());
if (response instanceof MySQLErrPacket) {
throw new RuntimeException(((MySQLErrPacket) response).getErrorMessage());
}
throw new RuntimeException("unexpected response type");
} catch (InterruptedException | ExecutionException e) {
......
......@@ -17,11 +17,15 @@
package org.apache.shardingsphere.shardingscaling.mysql.binlog;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.auth.ClientAuthenticationPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.auth.HandshakeInitializationPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.ErrorPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.OkPacket;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLAuthenticationMethod;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLCapabilityFlag;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.concurrent.Promise;
......@@ -32,6 +36,10 @@ import io.netty.util.concurrent.Promise;
@RequiredArgsConstructor
public final class MySQLNegotiateHandler extends ChannelInboundHandlerAdapter {
private static final int MAX_PACKET_SIZE = 1 << 24;
private static final int CHARACTER_SET = 33;
private final String username;
private final String password;
......@@ -42,37 +50,36 @@ public final class MySQLNegotiateHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
if (msg instanceof HandshakeInitializationPacket) {
HandshakeInitializationPacket handshake = (HandshakeInitializationPacket) msg;
ClientAuthenticationPacket clientAuth = new ClientAuthenticationPacket();
clientAuth.setSequenceNumber((byte) (handshake.getSequenceNumber() + 1));
clientAuth.setCharsetNumber((byte) 33);
clientAuth.setUsername(username);
clientAuth.setPassword(password);
clientAuth.setServerCapabilities(handshake.getServerCapabilities());
// use default database
clientAuth.setDatabaseName("mysql");
clientAuth.setAuthPluginData(joinAndCreateAuthPluginData(handshake));
clientAuth.setAuthPluginName(handshake.getAuthPluginName());
ctx.channel().writeAndFlush(clientAuth);
if (msg instanceof MySQLHandshakePacket) {
MySQLHandshakePacket handshake = (MySQLHandshakePacket) msg;
MySQLHandshakeResponse41Packet handshakeResponsePacket = new MySQLHandshakeResponse41Packet(1, MAX_PACKET_SIZE, CHARACTER_SET, username);
handshakeResponsePacket.setAuthResponse(generateAuthResponse(handshake.getAuthPluginData().getAuthPluginData()));
handshakeResponsePacket.setCapabilityFlags(generateClientCapability());
handshakeResponsePacket.setDatabase("mysql");
handshakeResponsePacket.setAuthPluginName(MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION);
ctx.channel().writeAndFlush(handshakeResponsePacket);
serverInfo = new ServerInfo();
serverInfo.setServerVersion(new ServerVersion(handshake.getServerVersion()));
return;
}
if (msg instanceof OkPacket) {
if (msg instanceof MySQLOKPacket) {
ctx.channel().pipeline().remove(this);
authResultCallback.setSuccess(serverInfo);
return;
}
ErrorPacket error = (ErrorPacket) msg;
MySQLErrPacket error = (MySQLErrPacket) msg;
ctx.channel().close();
throw new RuntimeException(error.getMessage());
throw new RuntimeException(error.getErrorMessage());
}
private int generateClientCapability() {
return MySQLCapabilityFlag.calculateCapabilityFlags(MySQLCapabilityFlag.CLIENT_LONG_PASSWORD, MySQLCapabilityFlag.CLIENT_LONG_FLAG,
MySQLCapabilityFlag.CLIENT_PROTOCOL_41, MySQLCapabilityFlag.CLIENT_INTERACTIVE, MySQLCapabilityFlag.CLIENT_TRANSACTIONS,
MySQLCapabilityFlag.CLIENT_SECURE_CONNECTION, MySQLCapabilityFlag.CLIENT_MULTI_STATEMENTS, MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH);
}
private byte[] joinAndCreateAuthPluginData(final HandshakeInitializationPacket handshakePacket) {
byte[] result = new byte[handshakePacket.getAuthPluginDataPart1().length + handshakePacket.getAuthPluginDataPart2().length];
System.arraycopy(handshakePacket.getAuthPluginDataPart1(), 0, result, 0, handshakePacket.getAuthPluginDataPart1().length);
System.arraycopy(handshakePacket.getAuthPluginDataPart2(), 0, result, handshakePacket.getAuthPluginDataPart1().length, handshakePacket.getAuthPluginDataPart2().length);
return result;
@SneakyThrows
private byte[] generateAuthResponse(final byte[] authPluginData) {
return (null == password || 0 == password.length()) ? new byte[0] : MySQLPasswordEncryptor.encryptWithMySQL41(password.getBytes(), authPluginData);
}
}
......@@ -17,17 +17,18 @@
package org.apache.shardingsphere.shardingscaling.mysql.binlog.codec;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.auth.HandshakeInitializationPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.command.query.MySQLColumnDefinition41Packet;
import org.apache.shardingsphere.database.protocol.mysql.packet.command.query.MySQLFieldCountPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.command.query.text.MySQLTextResultSetRowPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLEofPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import org.apache.shardingsphere.database.protocol.mysql.payload.MySQLPacketPayload;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.EofPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.ErrorPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.FieldPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.InternalResultSet;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.OkPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.ResultSetHeaderPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.RowDataPacket;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
......@@ -38,86 +39,78 @@ import java.util.List;
@Slf4j
public final class MySQLCommandPacketDecoder extends ByteToMessageDecoder {
private enum States { Initiate, ResponsePacket, FieldPacket, RowDataPacket }
private enum States { ResponsePacket, FieldPacket, RowDataPacket }
private States currentState = States.Initiate;
private States currentState = States.ResponsePacket;
private boolean auth;
private InternalResultSet internalResultSet;
@Override
protected void decode(final ChannelHandlerContext ctx, final ByteBuf in, final List<Object> out) {
// first packet from server is handshake initialization packet
if (States.Initiate.equals(currentState)) {
out.add(decodeHandshakeInitializationPacket(in));
currentState = States.ResponsePacket;
return;
MySQLPacketPayload payload = new MySQLPacketPayload(in);
if (!auth) {
out.add(decodeHandshakePacket(payload));
auth = true;
} else {
decodeCommandPacket(payload, out);
}
}
private void decodeCommandPacket(final MySQLPacketPayload payload, final List<Object> out) {
if (States.FieldPacket.equals(currentState)) {
decodeFieldPacket(in);
decodeFieldPacket(payload);
return;
}
if (States.RowDataPacket.equals(currentState)) {
decodeRowDataPacket(in, out);
decodeRowDataPacket(payload, out);
return;
}
decodeResponsePacket(in, out);
decodeResponsePacket(payload, out);
}
private HandshakeInitializationPacket decodeHandshakeInitializationPacket(final ByteBuf in) {
HandshakeInitializationPacket result = new HandshakeInitializationPacket();
result.fromByteBuf(in);
if (PacketConstants.PROTOCOL_VERSION != result.getProtocolVersion()) {
throw new UnsupportedOperationException();
}
private MySQLHandshakePacket decodeHandshakePacket(final MySQLPacketPayload payload) {
MySQLHandshakePacket result = new MySQLHandshakePacket(payload);
if (!AuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION.equals(result.getAuthPluginName())) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException("Only supported SECURE_PASSWORD_AUTHENTICATION server");
}
return result;
}
private void decodeFieldPacket(final ByteBuf in) {
if (PacketConstants.EOF_PACKET_MARK != in.getByte(0)) {
FieldPacket fieldPacket = new FieldPacket();
fieldPacket.fromByteBuf(in);
internalResultSet.getFieldDescriptors().add(fieldPacket);
private void decodeFieldPacket(final MySQLPacketPayload payload) {
if (PacketConstants.EOF_PACKET_MARK != payload.getByteBuf().getByte(1)) {
internalResultSet.getFieldDescriptors().add(new MySQLColumnDefinition41Packet(payload));
} else {
EofPacket eofPacket = new EofPacket();
eofPacket.fromByteBuf(in);
new MySQLEofPacket(payload);
currentState = States.RowDataPacket;
}
}
private void decodeRowDataPacket(final ByteBuf in, final List<Object> out) {
if (PacketConstants.EOF_PACKET_MARK != in.getByte(0)) {
RowDataPacket rowDataPacket = new RowDataPacket();
rowDataPacket.fromByteBuf(in);
internalResultSet.getFieldValues().add(rowDataPacket);
private void decodeRowDataPacket(final MySQLPacketPayload payload, final List<Object> out) {
if (PacketConstants.EOF_PACKET_MARK != payload.getByteBuf().getByte(1)) {
internalResultSet.getFieldValues().add(new MySQLTextResultSetRowPacket(payload, internalResultSet.getHeader().getColumnCount()));
} else {
EofPacket eofPacket = new EofPacket();
eofPacket.fromByteBuf(in);
new MySQLEofPacket(payload);
out.add(internalResultSet);
currentState = States.ResponsePacket;
internalResultSet = null;
}
}
private void decodeResponsePacket(final ByteBuf in, final List<Object> out) {
switch (in.getByte(0)) {
private void decodeResponsePacket(final MySQLPacketPayload payload, final List<Object> out) {
switch (payload.getByteBuf().getByte(1)) {
case PacketConstants.ERR_PACKET_MARK:
ErrorPacket error = new ErrorPacket();
error.fromByteBuf(in);
out.add(error);
out.add(new MySQLErrPacket(payload));
break;
case PacketConstants.OK_PACKET_MARK:
OkPacket ok = new OkPacket();
ok.fromByteBuf(in);
out.add(ok);
out.add(new MySQLOKPacket(payload));
break;
default:
ResultSetHeaderPacket resultSetHeaderPacket = new ResultSetHeaderPacket();
resultSetHeaderPacket.fromByteBuf(in);
MySQLFieldCountPacket fieldCountPacket = new MySQLFieldCountPacket(payload);
currentState = States.FieldPacket;
internalResultSet = new InternalResultSet(resultSetHeaderPacket);
internalResultSet = new InternalResultSet(fieldCountPacket);
}
}
}
......@@ -19,10 +19,14 @@ package org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.database.protocol.mysql.packet.command.query.MySQLColumnDefinition41Packet;
import org.apache.shardingsphere.database.protocol.mysql.packet.command.query.MySQLFieldCountPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.command.query.text.MySQLTextResultSetRowPacket;
import java.util.ArrayList;
import java.util.List;
/**
* Internal Result Set.
*/
......@@ -30,9 +34,9 @@ import java.util.List;
@Getter
public final class InternalResultSet {
private final ResultSetHeaderPacket header;
private final MySQLFieldCountPacket header;
private List<FieldPacket> fieldDescriptors = new ArrayList<>();
private List<MySQLColumnDefinition41Packet> fieldDescriptors = new ArrayList<>();
private List<RowDataPacket> fieldValues = new ArrayList<>();
private List<MySQLTextResultSetRowPacket> fieldValues = new ArrayList<>();
}
......@@ -20,8 +20,9 @@ package org.apache.shardingsphere.shardingscaling.mysql.binlog;
import io.netty.channel.Channel;
import io.netty.channel.ChannelPipeline;
import io.netty.util.concurrent.Promise;
import org.apache.shardingsphere.database.protocol.mysql.packet.command.query.text.query.MySQLComQueryPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.command.BinlogDumpCommandPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.command.QueryCommandPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.command.RegisterSlaveCommandPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.InternalResultSet;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.OkPacket;
......@@ -73,20 +74,20 @@ public final class MySQLConnectorTest {
@Test
public void assertExecute() throws NoSuchFieldException, IllegalAccessException {
mockChannelResponse(new OkPacket());
mockChannelResponse(new MySQLOKPacket(0));
ReflectionUtil.setFieldValueToClass(mySQLConnector, "channel", channel);
assertTrue(mySQLConnector.execute(""));
verify(channel).writeAndFlush(ArgumentMatchers.any(QueryCommandPacket.class));
verify(channel).writeAndFlush(ArgumentMatchers.any(MySQLComQueryPacket.class));
}
@Test
public void assertExecuteUpdate() throws NoSuchFieldException, IllegalAccessException {
OkPacket expected = new OkPacket();
MySQLOKPacket expected = new MySQLOKPacket(0, 10, 0);
ReflectionUtil.setFieldValueToClass(expected, "affectedRows", 10);
mockChannelResponse(expected);
ReflectionUtil.setFieldValueToClass(mySQLConnector, "channel", channel);
assertThat(mySQLConnector.executeUpdate(""), is(10));
verify(channel).writeAndFlush(ArgumentMatchers.any(QueryCommandPacket.class));
verify(channel).writeAndFlush(ArgumentMatchers.any(MySQLComQueryPacket.class));
}
@Test
......@@ -95,7 +96,7 @@ public final class MySQLConnectorTest {
mockChannelResponse(expected);
ReflectionUtil.setFieldValueToClass(mySQLConnector, "channel", channel);
assertThat(mySQLConnector.executeQuery(""), is(expected));
verify(channel).writeAndFlush(ArgumentMatchers.any(QueryCommandPacket.class));
verify(channel).writeAndFlush(ArgumentMatchers.any(MySQLComQueryPacket.class));
}
@Test
......
......@@ -21,10 +21,14 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.util.concurrent.Promise;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.auth.ClientAuthenticationPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.auth.HandshakeInitializationPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.ErrorPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.OkPacket;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLAuthenticationMethod;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLServerErrorCode;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthPluginData;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import org.apache.shardingsphere.shardingscaling.utils.ReflectionUtil;
import org.junit.Before;
......@@ -42,8 +46,6 @@ import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public final class MySQLNegotiateHandlerTest {
private static final String SERVER_VERSION = "5.7.13-log";
private static final String USER_NAME = "username";
private static final String PASSWORD = "password";
......@@ -71,23 +73,19 @@ public final class MySQLNegotiateHandlerTest {
@Test
public void assertChannelReadHandshakeInitPacket() throws NoSuchFieldException, IllegalAccessException {
HandshakeInitializationPacket handshakeInitializationPacket = new HandshakeInitializationPacket();
handshakeInitializationPacket.setServerVersion(SERVER_VERSION);
handshakeInitializationPacket.setAuthPluginName("");
handshakeInitializationPacket.setServerCapabilities(1);
handshakeInitializationPacket.setAuthPluginDataPart1(new byte[8]);
handshakeInitializationPacket.setAuthPluginDataPart2(new byte[12]);
mySQLNegotiateHandler.channelRead(channelHandlerContext, handshakeInitializationPacket);
verify(channel).writeAndFlush(ArgumentMatchers.any(ClientAuthenticationPacket.class));
MySQLHandshakePacket handshakePacket = new MySQLHandshakePacket(0, new MySQLAuthPluginData(new byte[8], new byte[12]));
handshakePacket.setAuthPluginName(MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION);
mySQLNegotiateHandler.channelRead(channelHandlerContext, handshakePacket);
verify(channel).writeAndFlush(ArgumentMatchers.any(MySQLHandshakeResponse41Packet.class));
ServerInfo serverInfo = ReflectionUtil.getFieldValueFromClass(mySQLNegotiateHandler, "serverInfo", ServerInfo.class);
assertThat(serverInfo.getServerVersion().getMajor(), is(5));
assertThat(serverInfo.getServerVersion().getMinor(), is(7));
assertThat(serverInfo.getServerVersion().getSeries(), is(13));
assertThat(serverInfo.getServerVersion().getMinor(), is(6));
assertThat(serverInfo.getServerVersion().getSeries(), is(4));
}
@Test
public void assertChannelReadOkPacket() throws NoSuchFieldException, IllegalAccessException {
OkPacket okPacket = new OkPacket();
MySQLOKPacket okPacket = new MySQLOKPacket(0);
ServerInfo serverInfo = new ServerInfo();
ReflectionUtil.setFieldValueToClass(mySQLNegotiateHandler, "serverInfo", serverInfo);
mySQLNegotiateHandler.channelRead(channelHandlerContext, okPacket);
......@@ -97,7 +95,7 @@ public final class MySQLNegotiateHandlerTest {
@Test(expected = RuntimeException.class)
public void assertChannelReadErrorPacket() {
ErrorPacket errorPacket = new ErrorPacket();
MySQLErrPacket errorPacket = new MySQLErrPacket(0, MySQLServerErrorCode.ER_NO_DB_ERROR);
mySQLNegotiateHandler.channelRead(channelHandlerContext, errorPacket);
}
}
......@@ -17,14 +17,19 @@
package org.apache.shardingsphere.shardingscaling.mysql.binlog.codec;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.auth.HandshakeInitializationPacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.ErrorPacket;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLServerInfo;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLStatusFlag;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLEofPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.InternalResultSet;
import org.apache.shardingsphere.shardingscaling.mysql.binlog.packet.response.OkPacket;
import org.apache.shardingsphere.shardingscaling.utils.ReflectionUtil;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
......@@ -43,7 +48,7 @@ public final class MySQLCommandPacketDecoderTest {
@Mock
private ByteBuf byteBuf;
@Test(expected = UnsupportedOperationException.class)
@Test(expected = IllegalArgumentException.class)
public void assertDecodeUnsupportedProtocolVersion() {
MySQLCommandPacketDecoder commandPacketDecoder = new MySQLCommandPacketDecoder();
commandPacketDecoder.decode(null, byteBuf, null);
......@@ -51,34 +56,22 @@ public final class MySQLCommandPacketDecoderTest {
@Test(expected = UnsupportedOperationException.class)
public void assertDecodeUnsupportedAuthenticationMethod() {
when(byteBuf.readByte()).thenReturn((byte) 0, (byte) MySQLServerInfo.PROTOCOL_VERSION);
when(byteBuf.readShortLE()).thenReturn((short) MySQLStatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue());
MySQLCommandPacketDecoder commandPacketDecoder = new MySQLCommandPacketDecoder();
when(byteBuf.readUnsignedByte()).thenReturn((short) PacketConstants.PROTOCOL_VERSION);
commandPacketDecoder.decode(null, byteBuf, null);
}
@Test
public void assertDecode() {
public void assertDecodeHandshakePacket() {
MySQLCommandPacketDecoder commandPacketDecoder = new MySQLCommandPacketDecoder();
List<Object> actual = new ArrayList<>();
commandPacketDecoder.decode(null, mockHandshakePacket(), actual);
assertInitial(actual);
actual.clear();
commandPacketDecoder.decode(null, mockOkPacket(), actual);
assertPacketByType(actual, OkPacket.class);
actual.clear();
commandPacketDecoder.decode(null, mockErrPacket(), actual);
assertPacketByType(actual, ErrorPacket.class);
actual.clear();
commandPacketDecoder.decode(null, mockResultSetPacket(), actual);
commandPacketDecoder.decode(null, mockResultSetPacket(), actual);
commandPacketDecoder.decode(null, mockEofPacket(), actual);
commandPacketDecoder.decode(null, mockResultSetPacket(), actual);
commandPacketDecoder.decode(null, mockEofPacket(), actual);
assertPacketByType(actual, InternalResultSet.class);
assertHandshakePacket(actual);
}
private ByteBuf mockHandshakePacket() {
String handshakePacket = "0a352e372e32312d6c6f6700090000004a592a1f725a0d0900fff7210200ff8115000000000000000000001a437b30323a4d2b514b5870006d"
String handshakePacket = "000a352e372e32312d6c6f6700090000004a592a1f725a0d0900fff7210200ff8115000000000000000000001a437b30323a4d2b514b5870006d"
+ "7973716c5f6e61746976655f70617373776f72640000000002000000";
byte[] handshakePacketBytes = ByteBufUtil.decodeHexDump(handshakePacket);
ByteBuf result = Unpooled.buffer(handshakePacketBytes.length);
......@@ -86,37 +79,82 @@ public final class MySQLCommandPacketDecoderTest {
return result;
}
private void assertInitial(final List<Object> actual) {
private void assertHandshakePacket(final List<Object> actual) {
assertThat(actual.size(), is(1));
assertThat(actual.get(0), instanceOf(HandshakeInitializationPacket.class));
HandshakeInitializationPacket actualPacket = (HandshakeInitializationPacket) actual.get(0);
assertThat(actualPacket.getProtocolVersion(), is((short) 0x0a));
assertThat(actual.get(0), instanceOf(MySQLHandshakePacket.class));
MySQLHandshakePacket actualPacket = (MySQLHandshakePacket) actual.get(0);
assertThat(actualPacket.getProtocolVersion(), is(0x0a));
assertThat(actualPacket.getServerVersion(), is("5.7.21-log"));
assertThat(actualPacket.getThreadId(), is(9L));
assertThat(actualPacket.getServerCharsetSet(), is((short) 33));
assertThat(actualPacket.getServerStatus(), is(2));
assertThat(actualPacket.getServerCapabilities(), is(63487));
assertThat(actualPacket.getServerCapabilities2(), is(33279));
assertThat(actualPacket.getConnectionId(), is(9));
assertThat(actualPacket.getCharacterSet(), is(33));
assertThat(actualPacket.getStatusFlag().getValue(), is(2));
assertThat(actualPacket.getCapabilityFlagsLower(), is(63487));
assertThat(actualPacket.getCapabilityFlagsUpper(), is(33279));
assertThat(actualPacket.getAuthPluginName(), is("mysql_native_password"));
}
@Test
public void assertDecodeOkPacket() throws NoSuchFieldException, IllegalAccessException {
MySQLCommandPacketDecoder commandPacketDecoder = new MySQLCommandPacketDecoder();
List<Object> actual = new ArrayList<>();
ReflectionUtil.setFieldValueToClass(commandPacketDecoder, "auth", true);
commandPacketDecoder.decode(null, mockOkPacket(), actual);
assertPacketByType(actual, MySQLOKPacket.class);
}
private ByteBuf mockOkPacket() {
when(byteBuf.getByte(0)).thenReturn(PacketConstants.OK_PACKET_MARK);
when(byteBuf.readByte()).thenReturn((byte) 0, (byte) MySQLOKPacket.HEADER);
when(byteBuf.getByte(1)).thenReturn((byte) MySQLOKPacket.HEADER);
return byteBuf;
}
@Test
public void assertDecodeErrPacket() throws NoSuchFieldException, IllegalAccessException {
MySQLCommandPacketDecoder commandPacketDecoder = new MySQLCommandPacketDecoder();
List<Object> actual = new ArrayList<>();
ReflectionUtil.setFieldValueToClass(commandPacketDecoder, "auth", true);
commandPacketDecoder.decode(null, mockErrPacket(), actual);
assertPacketByType(actual, MySQLErrPacket.class);
}
private ByteBuf mockErrPacket() {
when(byteBuf.getByte(0)).thenReturn(PacketConstants.ERR_PACKET_MARK);
when(byteBuf.getByte(1)).thenReturn((byte) MySQLErrPacket.HEADER);
when(byteBuf.readByte()).thenReturn((byte) 0, (byte) MySQLErrPacket.HEADER);
return byteBuf;
}
@Test
public void assertDecodeQueryCommPacket() throws NoSuchFieldException, IllegalAccessException {
MySQLCommandPacketDecoder commandPacketDecoder = new MySQLCommandPacketDecoder();
List<Object> actual = new ArrayList<>();
ReflectionUtil.setFieldValueToClass(commandPacketDecoder, "auth", true);
commandPacketDecoder.decode(null, mockEmptyResultSetPacket(), actual);
commandPacketDecoder.decode(null, mockFieldDefinition41Packet(), actual);
commandPacketDecoder.decode(null, mockEofPacket(), actual);
commandPacketDecoder.decode(null, mockEmptyResultSetPacket(), actual);
commandPacketDecoder.decode(null, mockEofPacket(), actual);
assertPacketByType(actual, InternalResultSet.class);
}
private ByteBuf mockEmptyResultSetPacket() {
when(byteBuf.getByte(1)).thenReturn((byte) 3);
return byteBuf;
}
private ByteBuf mockResultSetPacket() {
when(byteBuf.getByte(0)).thenReturn((byte) 1);
private ByteBuf mockFieldDefinition41Packet() {
when(byteBuf.getByte(1)).thenReturn((byte) 3);
when(byteBuf.readByte()).thenReturn((byte) 0, (byte) 3, (byte) 0x0c);
when(byteBuf.readBytes(new byte[3])).then(invocationOnMock -> {
byte[] input = invocationOnMock.getArgument(0);
System.arraycopy("def".getBytes(), 0, input, 0, input.length);
return byteBuf;
});
return byteBuf;
}
private ByteBuf mockEofPacket() {
when(byteBuf.getByte(0)).thenReturn(PacketConstants.EOF_PACKET_MARK);
when(byteBuf.getByte(1)).thenReturn((byte) MySQLEofPacket.HEADER);
when(byteBuf.readByte()).thenReturn((byte) 0, (byte) MySQLEofPacket.HEADER);
return byteBuf;
}
......
......@@ -100,8 +100,14 @@ public enum MySQLCapabilityFlag {
return 0;
}
/**
* Calculate capability flags.
*
* @param capabilities single capabilities of need to be calculated
* @return combined capabilities
*/
// TODO use xor to calculate lower and upper
private static int calculateCapabilityFlags(final MySQLCapabilityFlag... capabilities) {
public static int calculateCapabilityFlags(final MySQLCapabilityFlag... capabilities) {
int result = 0;
for (MySQLCapabilityFlag each : capabilities) {
result |= each.value;
......
......@@ -33,11 +33,11 @@ import java.util.List;
* @see <a href="https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow">ResultsetRow</a>
*/
@RequiredArgsConstructor
@Getter
public final class MySQLTextResultSetRowPacket implements MySQLPacket {
private static final int NULL = 0xfb;
@Getter
private final int sequenceId;
private final List<Object> data;
......
......@@ -34,6 +34,11 @@ public final class MySQLComQueryPacket extends MySQLCommandPacket {
private final String sql;
public MySQLComQueryPacket(final String sql) {
super(MySQLCommandPacketType.COM_QUERY);
this.sql = sql;
}
public MySQLComQueryPacket(final MySQLPacketPayload payload) {
super(MySQLCommandPacketType.COM_QUERY);
sql = payload.readStringEOF();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册