未验证 提交 56104b47 编写于 作者: A Andrei Pechkurov 提交者: GitHub

chore(core): add abstractions for encrypted sockets (#3678)

上级 e5a821cb
......@@ -45,7 +45,10 @@ public class PongMain {
// worker pool, which would handle jobs
final WorkerPool workerPool = new WorkerPool(() -> 1);
// event loop that accepts connections and publishes network events to event queue
final IODispatcher<PongConnectionContext> dispatcher = IODispatchers.create(dispatcherConf, new IOContextFactoryImpl<>(PongConnectionContext::new, 8));
final IODispatcher<PongConnectionContext> dispatcher = IODispatchers.create(
dispatcherConf,
new IOContextFactoryImpl<>(() -> new PongConnectionContext(PlainSocketFactory.INSTANCE, dispatcherConf.getNetworkFacade(), LOG), 8)
);
// event queue processor
final PongRequestProcessor processor = new PongRequestProcessor();
// event loop job
......@@ -65,6 +68,10 @@ public class PongMain {
private final DirectByteCharSequence flyweight = new DirectByteCharSequence();
private int writtenLen;
protected PongConnectionContext(SocketFactory socketFactory, NetworkFacade nf, Log log) {
super(socketFactory, nf, log);
}
@Override
public void clear() {
buf = bufStart;
......
......@@ -32,27 +32,45 @@ import io.questdb.cutlass.http.DefaultHttpAuthenticatorFactory;
import io.questdb.cutlass.http.HttpAuthenticatorFactory;
import io.questdb.cutlass.pgwire.DefaultPgWireAuthenticatorFactory;
import io.questdb.cutlass.pgwire.PgWireAuthenticatorFactory;
import io.questdb.network.PlainSocketFactory;
import io.questdb.network.SocketFactory;
import org.jetbrains.annotations.NotNull;
public class DefaultFactoryProvider implements FactoryProvider {
public static final DefaultFactoryProvider INSTANCE = new DefaultFactoryProvider();
@Override
public HttpAuthenticatorFactory getHttpAuthenticatorFactory() {
public @NotNull HttpAuthenticatorFactory getHttpAuthenticatorFactory() {
return DefaultHttpAuthenticatorFactory.INSTANCE;
}
@Override
public LineAuthenticatorFactory getLineAuthenticatorFactory() {
public @NotNull SocketFactory getHttpSocketFactory() {
return PlainSocketFactory.INSTANCE;
}
@Override
public @NotNull LineAuthenticatorFactory getLineAuthenticatorFactory() {
return DefaultLineAuthenticatorFactory.INSTANCE;
}
@Override
public PgWireAuthenticatorFactory getPgWireAuthenticatorFactory() {
public @NotNull SocketFactory getLineSocketFactory() {
return PlainSocketFactory.INSTANCE;
}
@Override
public @NotNull SocketFactory getPGWireSocketFactory() {
return PlainSocketFactory.INSTANCE;
}
@Override
public @NotNull PgWireAuthenticatorFactory getPgWireAuthenticatorFactory() {
return DefaultPgWireAuthenticatorFactory.INSTANCE;
}
@Override
public SecurityContextFactory getSecurityContextFactory() {
public @NotNull SecurityContextFactory getSecurityContextFactory() {
return AllowAllSecurityContextFactory.INSTANCE;
}
}
......@@ -28,18 +28,33 @@ import io.questdb.cairo.security.SecurityContextFactory;
import io.questdb.cutlass.auth.LineAuthenticatorFactory;
import io.questdb.cutlass.http.HttpAuthenticatorFactory;
import io.questdb.cutlass.pgwire.PgWireAuthenticatorFactory;
import io.questdb.network.SocketFactory;
import io.questdb.std.QuietCloseable;
import org.jetbrains.annotations.NotNull;
public interface FactoryProvider extends QuietCloseable {
@Override
default void close() {
}
@NotNull
HttpAuthenticatorFactory getHttpAuthenticatorFactory();
@NotNull
SocketFactory getHttpSocketFactory();
@NotNull
LineAuthenticatorFactory getLineAuthenticatorFactory();
@NotNull
SocketFactory getLineSocketFactory();
@NotNull
SocketFactory getPGWireSocketFactory();
@NotNull
PgWireAuthenticatorFactory getPgWireAuthenticatorFactory();
@NotNull
SecurityContextFactory getSecurityContextFactory();
}
......@@ -29,9 +29,11 @@ import io.questdb.cutlass.auth.LineAuthenticatorFactory;
import io.questdb.cutlass.http.DefaultHttpAuthenticatorFactory;
import io.questdb.cutlass.http.HttpAuthenticatorFactory;
import io.questdb.cutlass.pgwire.PgWireAuthenticatorFactory;
import io.questdb.network.PlainSocketFactory;
import io.questdb.network.SocketFactory;
import io.questdb.std.Misc;
import io.questdb.std.str.DirectByteCharSink;
import org.jetbrains.annotations.NotNull;
public class FactoryProviderImpl implements FactoryProvider {
private final LineAuthenticatorFactory lineAuthenticatorFactory;
......@@ -55,22 +57,37 @@ public class FactoryProviderImpl implements FactoryProvider {
}
@Override
public HttpAuthenticatorFactory getHttpAuthenticatorFactory() {
public @NotNull HttpAuthenticatorFactory getHttpAuthenticatorFactory() {
return DefaultHttpAuthenticatorFactory.INSTANCE;
}
@Override
public LineAuthenticatorFactory getLineAuthenticatorFactory() {
public @NotNull SocketFactory getHttpSocketFactory() {
return PlainSocketFactory.INSTANCE;
}
@Override
public @NotNull LineAuthenticatorFactory getLineAuthenticatorFactory() {
return lineAuthenticatorFactory;
}
@Override
public PgWireAuthenticatorFactory getPgWireAuthenticatorFactory() {
public @NotNull SocketFactory getLineSocketFactory() {
return PlainSocketFactory.INSTANCE;
}
@Override
public @NotNull SocketFactory getPGWireSocketFactory() {
return PlainSocketFactory.INSTANCE;
}
@Override
public @NotNull PgWireAuthenticatorFactory getPgWireAuthenticatorFactory() {
return pgWireAuthenticatorFactory;
}
@Override
public SecurityContextFactory getSecurityContextFactory() {
public @NotNull SecurityContextFactory getSecurityContextFactory() {
return securityContextFactory;
}
}
......@@ -2733,11 +2733,6 @@ public class PropServerConfiguration implements ServerConfiguration {
return httpNetConnectionHint;
}
@Override
public int getInitialBias() {
return IOOperation.READ;
}
@Override
public KqueueFacade getKqueueFacade() {
return KqueueFacadeImpl.INSTANCE;
......@@ -2820,11 +2815,6 @@ public class PropServerConfiguration implements ServerConfiguration {
return httpMinNetConnectionHint;
}
@Override
public int getInitialBias() {
return IOOperation.READ;
}
@Override
public KqueueFacade getKqueueFacade() {
return KqueueFacadeImpl.INSTANCE;
......@@ -2878,6 +2868,11 @@ public class PropServerConfiguration implements ServerConfiguration {
return httpMinIODispatcherConfiguration;
}
@Override
public FactoryProvider getFactoryProvider() {
return factoryProvider;
}
@Override
public HttpContextConfiguration getHttpContextConfiguration() {
return httpContextConfiguration;
......@@ -2946,6 +2941,11 @@ public class PropServerConfiguration implements ServerConfiguration {
return httpIODispatcherConfiguration;
}
@Override
public FactoryProvider getFactoryProvider() {
return factoryProvider;
}
@Override
public HttpContextConfiguration getHttpContextConfiguration() {
return httpContextConfiguration;
......@@ -3265,11 +3265,6 @@ public class PropServerConfiguration implements ServerConfiguration {
public boolean isSymbolAsFieldSupported() {
return symbolAsFieldSupported;
}
@Override
public boolean readOnlySecurityContext() {
return httpReadOnlySecurityContext || isReadOnlyInstance;
}
}
private class PropLineTcpReceiverIODispatcherConfiguration implements IODispatcherConfiguration {
......@@ -3309,11 +3304,6 @@ public class PropServerConfiguration implements ServerConfiguration {
return lineTcpNetConnectionHint;
}
@Override
public int getInitialBias() {
return BIAS_READ;
}
@Override
public KqueueFacade getKqueueFacade() {
return KqueueFacadeImpl.INSTANCE;
......@@ -3742,11 +3732,6 @@ public class PropServerConfiguration implements ServerConfiguration {
return pgNetConnectionHint;
}
@Override
public int getInitialBias() {
return BIAS_READ;
}
@Override
public KqueueFacade getKqueueFacade() {
return KqueueFacadeImpl.INSTANCE;
......
......@@ -210,10 +210,7 @@ public class ServerMain implements Closeable {
final String rootDir = new File(configuration.getCairoConfiguration().getRoot()).getParent();
final String absPath = new File(rootDir, configuration.getLineTcpReceiverConfiguration().getAuthDB()).getAbsolutePath();
CharSequenceObjHashMap<PublicKey> authDb = AuthUtils.loadAuthDb(absPath);
authenticatorFactory = new EllipticCurveAuthenticatorFactory(
configuration.getLineTcpReceiverConfiguration().getNetworkFacade(),
() -> new StaticChallengeResponseMatcher(authDb)
);
authenticatorFactory = new EllipticCurveAuthenticatorFactory(() -> new StaticChallengeResponseMatcher(authDb));
} else {
authenticatorFactory = DefaultLineAuthenticatorFactory.INSTANCE;
}
......
......@@ -9,7 +9,6 @@ public class DefaultDdlListener implements DdlListener {
@Override
public void onColumnRenamed(SecurityContext securityContext, TableToken tableToken, CharSequence oldColumnName, CharSequence newColumnName) {
}
@Override
......
......@@ -93,7 +93,7 @@ public class NetworkSqlExecutionCircuitBreaker implements SqlExecutionCircuitBre
@Override
public void close() {
buffer = Unsafe.free(buffer, bufferSize, this.memoryTag);
buffer = Unsafe.free(buffer, bufferSize, memoryTag);
fd = -1;
}
......@@ -149,23 +149,23 @@ public class NetworkSqlExecutionCircuitBreaker implements SqlExecutionCircuitBre
this.timeout = timeout;
}
@Override
public void statefulThrowExceptionIfTripped() {
public void statefulThrowExceptionIfTimeout() {
// Same as statefulThrowExceptionIfTripped but does not check the connection state.
// Useful to check timeout before trying to send something on the connection.
if (testCount < throttle) {
testCount++;
} else {
statefulThrowExceptionIfTrippedNoThrottle();
testCount = 0;
testTimeout();
}
}
public void statefulThrowExceptionIfTimeout() {
// Same as statefulThrowExceptionIfTripped but does not check the connection state.
// Useful to check timeout before trying to send something on the connection.
@Override
public void statefulThrowExceptionIfTripped() {
if (testCount < throttle) {
testCount++;
} else {
testCount = 0;
testTimeout();
statefulThrowExceptionIfTrippedNoThrottle();
}
}
......
......@@ -38,7 +38,7 @@ public class DefaultWalListener implements WalListener {
}
@Override
public void segmentClosed(final TableToken tabletoken, long txn, int walId, int segmentId) {
public void segmentClosed(final TableToken tableToken, long txn, int walId, int segmentId) {
}
@Override
......
......@@ -32,7 +32,7 @@ public interface WalListener {
void nonDataTxnCommitted(TableToken tableToken, long txn, long timestamp);
void segmentClosed(final TableToken tabletoken, long txn, int walId, int segmentId);
void segmentClosed(final TableToken tableToken, long txn, int walId, int segmentId);
void tableCreated(TableToken tableToken, long timestamp);
......
......@@ -459,7 +459,7 @@ public interface Sender extends Closeable {
* the Sender to load a trust store from a classpath.
*
* @param trustStorePath a path to a trust store.
* @param trustStorePassword a password to for the trustore
* @param trustStorePassword a password to for the truststore
* @return an instance of LineSenderBuilder for further configuration
*/
public LineSenderBuilder customTrustStore(String trustStorePath, char[] trustStorePassword) {
......
......@@ -24,6 +24,9 @@
package io.questdb.cutlass.auth;
import io.questdb.network.Socket;
import org.jetbrains.annotations.NotNull;
public class AnonymousAuthenticator implements Authenticator {
public static final AnonymousAuthenticator INSTANCE = new AnonymousAuthenticator();
......@@ -49,7 +52,7 @@ public class AnonymousAuthenticator implements Authenticator {
}
@Override
public void init(int fd, long recvBuffer, long recvBufferLimit, long sendBuffer, long sendBufferLimit) {
public void init(@NotNull Socket socket, long recvBuffer, long recvBufferLimit, long sendBuffer, long sendBufferLimit) {
}
@Override
......
......@@ -24,7 +24,9 @@
package io.questdb.cutlass.auth;
import io.questdb.network.Socket;
import io.questdb.std.QuietCloseable;
import org.jetbrains.annotations.NotNull;
public interface Authenticator extends QuietCloseable {
......@@ -49,7 +51,7 @@ public interface Authenticator extends QuietCloseable {
int handleIO() throws AuthenticatorException;
void init(int fd, long recvBuffer, long recvBufferLimit, long sendBuffer, long sendBufferLimit);
void init(@NotNull Socket socket, long recvBuffer, long recvBufferLimit, long sendBuffer, long sendBufferLimit);
boolean isAuthenticated();
}
......@@ -25,22 +25,19 @@
package io.questdb.cutlass.auth;
import io.questdb.cutlass.line.tcp.auth.EllipticCurveAuthenticator;
import io.questdb.network.NetworkFacade;
import io.questdb.std.ObjectFactory;
public class EllipticCurveAuthenticatorFactory implements LineAuthenticatorFactory {
private final ObjectFactory<? extends ChallengeResponseMatcher> matcherFactory;
private final NetworkFacade networkFacade;
public EllipticCurveAuthenticatorFactory(NetworkFacade networkFacade, ObjectFactory<? extends ChallengeResponseMatcher> matcherFactory) {
this.networkFacade = networkFacade;
public EllipticCurveAuthenticatorFactory(ObjectFactory<? extends ChallengeResponseMatcher> matcherFactory) {
this.matcherFactory = matcherFactory;
}
@Override
public Authenticator getLineTCPAuthenticator() {
return new EllipticCurveAuthenticator(
networkFacade,
matcherFactory.newInstance());
matcherFactory.newInstance()
);
}
}
......@@ -104,6 +104,11 @@ public class DefaultHttpServerConfiguration implements HttpServerConfiguration {
return dispatcherConfiguration;
}
@Override
public FactoryProvider getFactoryProvider() {
return DefaultFactoryProvider.INSTANCE;
}
@Override
public HttpContextConfiguration getHttpContextConfiguration() {
return httpContextConfiguration;
......
......@@ -24,6 +24,7 @@
package io.questdb.cutlass.http;
import io.questdb.FactoryProvider;
import io.questdb.mp.WorkerPoolConfiguration;
import io.questdb.network.IODispatcherConfiguration;
......@@ -31,6 +32,8 @@ public interface HttpMinServerConfiguration extends WorkerPoolConfiguration {
IODispatcherConfiguration getDispatcherConfiguration();
FactoryProvider getFactoryProvider();
HttpContextConfiguration getHttpContextConfiguration();
WaitProcessorConfiguration getWaitProcessorConfiguration();
......
......@@ -58,11 +58,11 @@ public class HttpResponseSink implements Closeable, Mutable {
private boolean compressionComplete;
private int crc = 0;
private boolean deflateBeforeSend = false;
private int fd;
private boolean headersSent;
private Socket socket;
private long total = 0;
private long totalBytesSent = 0;
private long z_streamp = 0;
private long zStreamPtr = 0;
public HttpResponseSink(HttpContextConfiguration configuration) {
final int responseBufferSize = Numbers.ceilPow2(configuration.getSendBufferSize());
......@@ -86,13 +86,13 @@ public class HttpResponseSink implements Closeable, Mutable {
@Override
public void close() {
if (z_streamp != 0) {
Zip.deflateEnd(z_streamp);
z_streamp = 0;
if (zStreamPtr != 0) {
Zip.deflateEnd(zStreamPtr);
zStreamPtr = 0;
compressOutBuffer.close();
}
buffer.close();
fd = -1;
socket = null;
}
public HttpChunkedResponseSocket getChunkedSocket() {
......@@ -132,8 +132,8 @@ public class HttpResponseSink implements Closeable, Mutable {
public void setDeflateBeforeSend(boolean deflateBeforeSend) {
this.deflateBeforeSend = deflateBeforeSend;
if (z_streamp == 0 && deflateBeforeSend) {
z_streamp = Zip.deflateInit();
if (zStreamPtr == 0 && deflateBeforeSend) {
zStreamPtr = Zip.deflateInit();
compressOutBuffer.reopen();
}
}
......@@ -149,9 +149,9 @@ public class HttpResponseSink implements Closeable, Mutable {
int nInAvailable = (int) buffer.getReadNAvailable();
if (nInAvailable > 0) {
long inAddress = buffer.getReadAddress();
LOG.debug().$("Zip.setInput [inAddress=").$(inAddress).$(", nInAvailable=").$(nInAvailable).$(']').$();
LOG.debug().$("Zip.setInput [inAddress=").$(inAddress).$(", nInAvailable=").$(nInAvailable).I$();
buffer.write64BitZeroPadding();
Zip.setInput(z_streamp, inAddress, nInAvailable);
Zip.setInput(zStreamPtr, inAddress, nInAvailable);
}
int ret;
......@@ -160,9 +160,9 @@ public class HttpResponseSink implements Closeable, Mutable {
do {
int sz = (int) compressOutBuffer.getWriteNAvailable() - 8;
long p = compressOutBuffer.getWriteAddress(0);
LOG.debug().$("deflate starting [p=").$(p).$(", sz=").$(sz).$(", chunkedRequestDone=").$(chunkedRequestDone).$(']').$();
ret = Zip.deflate(z_streamp, p, sz, chunkedRequestDone);
len = sz - Zip.availOut(z_streamp);
LOG.debug().$("deflate starting [p=").$(p).$(", sz=").$(sz).$(", chunkedRequestDone=").$(chunkedRequestDone).I$();
ret = Zip.deflate(zStreamPtr, p, sz, chunkedRequestDone);
len = sz - Zip.availOut(zStreamPtr);
compressOutBuffer.onWrite(len);
if (ret < 0) {
// This is not an error, zlib just couldn't do any work with the input/output buffers it was provided.
......@@ -173,7 +173,7 @@ public class HttpResponseSink implements Closeable, Mutable {
}
}
int availIn = Zip.availIn(z_streamp);
int availIn = Zip.availIn(zStreamPtr);
int nInConsumed = nInAvailable - availIn;
if (nInConsumed > 0) {
this.crc = Zip.crc32(this.crc, buffer.getReadAddress(), nInConsumed);
......@@ -182,7 +182,7 @@ public class HttpResponseSink implements Closeable, Mutable {
nInAvailable = availIn;
}
LOG.debug().$("deflate finished [ret=").$(ret).$(", len=").$(len).$(", availIn=").$(availIn).$(']').$();
LOG.debug().$("deflate finished [ret=").$(ret).$(", len=").$(len).$(", availIn=").$(availIn).I$();
} while (len == 0 && nInAvailable > 0);
if (nInAvailable == 0) {
......@@ -223,14 +223,18 @@ public class HttpResponseSink implements Closeable, Mutable {
sendBuffer(buffer);
}
private int getFd() {
return socket != null ? socket.getFd() : -1;
}
private void prepareHeaderSink() {
buffer.prepareToReadFromBuffer(false, false);
headerImpl.prepareToSend();
}
private void resetZip() {
if (z_streamp != 0) {
Zip.deflateReset(z_streamp);
if (zStreamPtr != 0) {
Zip.deflateReset(zStreamPtr);
compressOutBuffer.clear();
crc = 0;
total = 0;
......@@ -243,13 +247,13 @@ public class HttpResponseSink implements Closeable, Mutable {
private void sendBuffer(ChunkBuffer sendBuf) throws PeerDisconnectedException, PeerIsSlowToReadException {
int nSend = (int) sendBuf.getReadNAvailable();
while (nSend > 0) {
int n = nf.send(fd, sendBuf.getReadAddress(), nSend);
int n = socket.send(sendBuf.getReadAddress(), nSend);
if (n < 0) {
// disconnected
LOG.error()
.$("disconnected [errno=").$(nf.errno())
.$(", fd=").$(fd)
.$(']').$();
.$(", fd=").$(socket.getFd())
.I$();
throw PeerDisconnectedException.INSTANCE;
}
if (n == 0) {
......@@ -278,9 +282,9 @@ public class HttpResponseSink implements Closeable, Mutable {
return totalBytesSent;
}
void of(int fd) {
this.fd = fd;
if (fd > -1) {
void of(Socket socket) {
this.socket = socket;
if (socket != null) {
this.buffer.reopen();
}
}
......@@ -399,7 +403,7 @@ public class HttpResponseSink implements Closeable, Mutable {
int len = EOF_CHUNK.length();
Chars.asciiStrCpy(EOF_CHUNK, len, _wptr);
_wptr += len;
LOG.debug().$("end chunk sent [fd=").$(fd).$(']').$();
LOG.debug().$("end chunk sent [fd=").$(getFd()).I$();
}
}
......@@ -457,7 +461,7 @@ public class HttpResponseSink implements Closeable, Mutable {
@Override
public void shutdownWrite() {
nf.shutdown(fd, Net.SHUT_WR);
socket.shutdown(Net.SHUT_WR);
}
@Override
......
......@@ -63,10 +63,7 @@ public class HttpServer implements Closeable {
}
this.httpContextFactory = new HttpContextFactory(configuration.getHttpContextConfiguration(), metrics);
this.dispatcher = IODispatchers.create(
configuration.getDispatcherConfiguration(),
httpContextFactory
);
this.dispatcher = IODispatchers.create(configuration.getDispatcherConfiguration(), httpContextFactory);
pool.assign(dispatcher);
this.rescheduleContext = new WaitProcessor(configuration.getWaitProcessorConfiguration());
pool.assign(this.rescheduleContext);
......
......@@ -214,9 +214,4 @@ public class DefaultLineTcpReceiverConfiguration implements LineTcpReceiverConfi
public boolean isSymbolAsFieldSupported() {
return false;
}
@Override
public boolean readOnlySecurityContext() {
return false;
}
}
......@@ -42,6 +42,7 @@ import io.questdb.std.*;
import io.questdb.std.datetime.millitime.MillisecondClock;
import io.questdb.std.str.ByteCharSequence;
import io.questdb.std.str.DirectByteCharSequence;
import org.jetbrains.annotations.NotNull;
public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext> {
private static final Log LOG = LogFactory.getLog(LineTcpConnectionContext.class);
......@@ -71,6 +72,7 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
private long nextCommitTime;
public LineTcpConnectionContext(LineTcpReceiverConfiguration configuration, LineTcpMeasurementScheduler scheduler, Metrics metrics) {
super(configuration.getFactoryProvider().getLineSocketFactory(), configuration.getNetworkFacade(), LOG);
this.configuration = configuration;
nf = configuration.getNetworkFacade();
disconnectOnError = configuration.getDisconnectOnError();
......@@ -78,8 +80,6 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
this.metrics = metrics;
this.milliClock = configuration.getMillisecondClock();
parser = new LineTcpParser(configuration.isStringAsTagSupported(), configuration.isSymbolAsFieldSupported());
recvBufStart = Unsafe.malloc(configuration.getNetMsgBufferSize(), MemoryTag.NATIVE_ILP_RSS);
recvBufEnd = recvBufStart + configuration.getNetMsgBufferSize();
this.authenticator = configuration.getFactoryProvider().getLineAuthenticatorFactory().getLineTCPAuthenticator();
clear();
this.checkIdleInterval = configuration.getMaintenanceInterval();
......@@ -103,9 +103,10 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
@Override
public void clear() {
super.clear();
securityContext = DenyAllSecurityContext.INSTANCE;
authenticator.clear();
recvBufPos = recvBufStart;
recvBufStart = recvBufEnd = recvBufPos = Unsafe.free(recvBufStart, recvBufEnd - recvBufStart, MemoryTag.NATIVE_ILP_RSS);
peerDisconnected = false;
resetParser();
ObjList<ByteCharSequence> keys = tableUpdateDetailsUtf8.keys();
......@@ -119,10 +120,8 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
@Override
public void close() {
this.fd = -1;
recvBufStart = recvBufEnd = recvBufPos = Unsafe.free(recvBufStart, recvBufEnd - recvBufStart, MemoryTag.NATIVE_ILP_RSS);
Misc.free(authenticator);
clear();
Misc.free(authenticator);
}
public long commitWalTables(long wallClockMillis) {
......@@ -191,14 +190,30 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
}
@Override
public LineTcpConnectionContext of(int fd, IODispatcher<LineTcpConnectionContext> dispatcher) {
authenticator.init(fd, recvBufStart, recvBufEnd, 0, 0);
public void init() {
if (socket.supportsTls()) {
if (socket.startTlsSession() != 0) {
throw CairoException.nonCritical().put("failed to start TLS session");
}
}
}
@Override
public LineTcpConnectionContext of(int fd, @NotNull IODispatcher<LineTcpConnectionContext> dispatcher) {
super.of(fd, dispatcher);
if (recvBufStart == 0) {
recvBufStart = Unsafe.malloc(configuration.getNetMsgBufferSize(), MemoryTag.NATIVE_ILP_RSS);
recvBufEnd = recvBufStart + configuration.getNetMsgBufferSize();
recvBufPos = recvBufStart;
resetParser();
}
authenticator.init(socket, recvBufStart, recvBufEnd, 0, 0);
if (authenticator.isAuthenticated() && securityContext == DenyAllSecurityContext.INSTANCE) {
// when security context has not been set by anything else (subclass) we assume
// this is an authenticated, anonymous user
securityContext = configuration.getFactoryProvider().getSecurityContextFactory().getInstance(null, SecurityContextFactory.ILP);
}
return super.of(fd, dispatcher);
return this;
}
private boolean checkQueueFullLogHysteresis() {
......@@ -212,16 +227,17 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
private void doHandleDisconnectEvent() {
if (parser.getBufferAddress() == recvBufEnd) {
LOG.error().$('[').$(fd).$("] buffer overflow [line.tcp.msg.buffer.size=").$(recvBufEnd - recvBufStart).$(']').$();
LOG.error().$('[').$(getFd()).$("] buffer overflow [line.tcp.msg.buffer.size=").$(recvBufEnd - recvBufStart).$(']').$();
return;
}
if (peerDisconnected) {
// Peer disconnected, we have now finished disconnect our end
if (recvBufPos != recvBufStart) {
LOG.info().$('[').$(fd).$("] peer disconnected with partial measurement, ").$(recvBufPos - recvBufStart).$(" unprocessed bytes").$();
LOG.info().$('[').$(getFd()).$("] peer disconnected with partial measurement, ").$(recvBufPos - recvBufStart)
.$(" unprocessed bytes").$();
} else {
LOG.info().$('[').$(fd).$("] peer disconnected").$();
LOG.info().$('[').$(getFd()).$("] peer disconnected").$();
}
}
}
......@@ -258,7 +274,7 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
int position = (int) (parser.getBufferAddress() - recvBufStartOfMeasurement);
assert position >= 0;
LOG.error()
.$('[').$(fd)
.$('[').$(getFd())
.$("] could not parse measurement, ").$(parser.getErrorCode())
.$(" at ").$(position)
.$(", line (may be mangled due to partial parsing): '")
......@@ -324,7 +340,7 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
if (scheduler.scheduleEvent(getSecurityContext(), netIoJob, this, parser)) {
// Waiting for writer threads to drain queue, request callback as soon as possible
if (checkQueueFullLogHysteresis()) {
LOG.debug().$('[').$(fd).$("] queue full").$();
LOG.debug().$('[').$(getFd()).$("] queue full").$();
}
return IOContextResult.QUEUE_FULL;
}
......@@ -334,7 +350,6 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
}
startNewMeasurement();
continue;
}
......@@ -361,7 +376,7 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
}
} catch (CairoException ex) {
LOG.error()
.$('[').$(fd).$("] could not process line data [table=").$(parser.getMeasurementName())
.$('[').$(getFd()).$("] could not process line data [table=").$(parser.getMeasurementName())
.$(", msg=").$(ex.getFlyweightMessage())
.$(", errno=").$(ex.getErrno())
.I$();
......@@ -374,7 +389,7 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
goodMeasurement = false;
} catch (Throwable ex) {
LOG.critical()
.$('[').$(fd).$("] could not process line data [table=").$(parser.getMeasurementName())
.$('[').$(getFd()).$("] could not process line data [table=").$(parser.getMeasurementName())
.$(", ex=").$(ex)
.I$();
// This is a critical error, so we treat it as an unhandled one.
......@@ -388,7 +403,7 @@ public class LineTcpConnectionContext extends IOContext<LineTcpConnectionContext
int bufferRemaining = (int) (recvBufEnd - recvBufPos);
final int orig = bufferRemaining;
if (bufferRemaining > 0 && !peerDisconnected) {
int bytesRead = nf.recv(fd, recvBufPos, bufferRemaining);
int bytesRead = socket.recv(recvBufPos, bufferRemaining);
if (bytesRead > 0) {
recvBufPos += bytesRead;
bufferRemaining -= bytesRead;
......
......@@ -56,10 +56,7 @@ public class LineTcpReceiver implements Closeable {
factory,
configuration.getConnectionPoolInitialCapacity()
);
this.dispatcher = IODispatchers.create(
configuration.getDispatcherConfiguration(),
contextFactory
);
this.dispatcher = IODispatchers.create(configuration.getDispatcherConfiguration(), contextFactory);
ioWorkerPool.assign(dispatcher);
this.scheduler = new LineTcpMeasurementScheduler(configuration, engine, ioWorkerPool, dispatcher, writerWorkerPool);
......
......@@ -34,6 +34,7 @@ import io.questdb.std.datetime.microtime.MicrosecondClock;
import io.questdb.std.datetime.millitime.MillisecondClock;
public interface LineTcpReceiverConfiguration {
String getAuthDB();
boolean getAutoCreateNewColumns();
......@@ -101,6 +102,4 @@ public interface LineTcpReceiverConfiguration {
boolean isStringToCharCastAllowed();
boolean isSymbolAsFieldSupported();
boolean readOnlySecurityContext();
}
......@@ -93,7 +93,7 @@ public final class PlainTcpLineChannel implements LineChannel {
@Override
public int receive(long ptr, int len) {
return nf.recv(fd, ptr, len);
return nf.recvRaw(fd, ptr, len);
}
@Override
......@@ -101,7 +101,7 @@ public final class PlainTcpLineChannel implements LineChannel {
if (len > 0) {
long o = 0;
while (len > 0) {
int n = nf.send(fd, ptr + o, len);
int n = nf.sendRaw(fd, ptr + o, len);
if (n > 0) {
len -= n;
o += n;
......
......@@ -31,12 +31,13 @@ import io.questdb.cutlass.auth.AuthenticatorException;
import io.questdb.cutlass.auth.ChallengeResponseMatcher;
import io.questdb.log.Log;
import io.questdb.log.LogFactory;
import io.questdb.network.NetworkFacade;
import io.questdb.network.Socket;
import io.questdb.std.Chars;
import io.questdb.std.MemoryTag;
import io.questdb.std.ThreadLocal;
import io.questdb.std.Unsafe;
import io.questdb.std.str.DirectByteCharSequence;
import org.jetbrains.annotations.NotNull;
import java.security.SecureRandom;
......@@ -46,20 +47,18 @@ public class EllipticCurveAuthenticator implements Authenticator {
private static final ThreadLocal<SecureRandom> tlSrand = new ThreadLocal<>(SecureRandom::new);
private final ChallengeResponseMatcher challengeResponseMatcher;
private final NetworkFacade nf;
private final DirectByteCharSequence userNameFlyweight = new DirectByteCharSequence();
protected long recvBufPseudoStart;
private AuthState authState;
private long challengePtr;
private int fd;
private String principal;
private long recvBufEnd;
private long recvBufPos;
private long recvBufStart;
private Socket socket;
public EllipticCurveAuthenticator(NetworkFacade networkFacade, ChallengeResponseMatcher challengeResponseMatcher) {
public EllipticCurveAuthenticator(ChallengeResponseMatcher challengeResponseMatcher) {
this.challengeResponseMatcher = challengeResponseMatcher;
this.nf = networkFacade;
this.challengePtr = Unsafe.malloc(AuthUtils.CHALLENGE_LEN, MemoryTag.NATIVE_DEFAULT);
}
......@@ -105,11 +104,11 @@ public class EllipticCurveAuthenticator implements Authenticator {
}
@Override
public void init(int fd, long recvBuffer, long recvBufferLimit, long sendBuffer, long sendBufferLimit) {
public void init(@NotNull Socket socket, long recvBuffer, long recvBufferLimit, long sendBuffer, long sendBufferLimit) {
if (recvBufferLimit - recvBuffer < MIN_BUF_SIZE) {
throw CairoException.critical(0).put("Minimum buffer length is ").put(MIN_BUF_SIZE);
}
this.fd = fd;
this.socket = socket;
authState = AuthState.WAITING_FOR_KEY_ID;
this.recvBufStart = recvBuffer;
this.recvBufPos = recvBuffer;
......@@ -124,11 +123,11 @@ public class EllipticCurveAuthenticator implements Authenticator {
private int findLineEnd() throws AuthenticatorException {
int bufferRemaining = (int) (recvBufEnd - recvBufPos);
if (bufferRemaining > 0) {
int bytesRead = nf.recv(fd, recvBufPos, bufferRemaining);
int bytesRead = socket.recv(recvBufPos, bufferRemaining);
if (bytesRead > 0) {
recvBufPos += bytesRead;
} else if (bytesRead < 0) {
LOG.info().$('[').$(fd).$("] authentication disconnected by peer when reading token").$();
LOG.info().$('[').$(socket.getFd()).$("] authentication disconnected by peer when reading token").$();
throw AuthenticatorException.INSTANCE;
}
}
......@@ -149,7 +148,7 @@ public class EllipticCurveAuthenticator implements Authenticator {
}
if (recvBufPos == recvBufEnd) {
LOG.info().$('[').$(fd).$("] authentication token is too long").$();
LOG.info().$('[').$(socket.getFd()).$("] authentication token is too long").$();
throw AuthenticatorException.INSTANCE;
}
......@@ -161,7 +160,7 @@ public class EllipticCurveAuthenticator implements Authenticator {
if (lineEnd != -1) {
userNameFlyweight.of(recvBufStart, recvBufStart + lineEnd);
principal = Chars.toString(userNameFlyweight);
LOG.info().$('[').$(fd).$("] authentication read key id [keyId=").$(userNameFlyweight).$(']').$();
LOG.info().$('[').$(socket.getFd()).$("] authentication read key id [keyId=").$(userNameFlyweight).I$();
recvBufPos = recvBufStart;
// Generate a challenge with printable ASCII characters 0x20 to 0x7e
int n = 0;
......@@ -182,7 +181,7 @@ public class EllipticCurveAuthenticator implements Authenticator {
int n = AuthUtils.CHALLENGE_LEN + 1 - (int) (recvBufPos - recvBufStart);
assert n > 0;
while (true) {
int nWritten = nf.send(fd, recvBufPos, n);
int nWritten = socket.send(recvBufPos, n);
if (nWritten > 0) {
if (n == nWritten) {
recvBufPos = recvBufStart;
......@@ -199,7 +198,7 @@ public class EllipticCurveAuthenticator implements Authenticator {
break;
}
LOG.info().$('[').$(fd).$("] authentication peer disconnected when challenge was being sent").$();
LOG.info().$('[').$(socket.getFd()).$("] authentication peer disconnected when challenge was being sent").$();
throw AuthenticatorException.INSTANCE;
}
......@@ -208,18 +207,18 @@ public class EllipticCurveAuthenticator implements Authenticator {
if (lineEnd != -1) {
// Verify signature
if (lineEnd > AuthUtils.MAX_SIGNATURE_LENGTH_BASE64) {
LOG.info().$('[').$(fd).$("] authentication signature is too long").$();
LOG.info().$('[').$(socket.getFd()).$("] authentication signature is too long").$();
throw AuthenticatorException.INSTANCE;
}
authState = AuthState.FAILED;
boolean verified = challengeResponseMatcher.verifyLineToken(principal, challengePtr, AuthUtils.CHALLENGE_LEN, recvBufStart, lineEnd);
if (!verified) {
LOG.info().$('[').$(fd).$("] authentication failed, signature was not verified").$();
LOG.info().$('[').$(socket.getFd()).$("] authentication failed, signature was not verified").$();
throw AuthenticatorException.INSTANCE;
}
authState = AuthState.COMPLETE;
LOG.info().$('[').$(fd).$("] authentication success").$();
LOG.info().$('[').$(socket.getFd()).$("] authentication success").$();
}
return lineEnd;
}
......@@ -237,5 +236,4 @@ public class EllipticCurveAuthenticator implements Authenticator {
this.ioContextResult = ioContextResult;
}
}
}
......@@ -56,7 +56,7 @@ public class LineUdpReceiver extends AbstractLineProtoUdpReceiver {
protected boolean runSerially() {
boolean ran = false;
int count;
while ((count = nf.recv(fd, buf, bufLen)) > 0) {
while ((count = nf.recvRaw(fd, buf, bufLen)) > 0) {
lexer.parse(buf, buf + count);
lexer.parseLast();
......
......@@ -56,7 +56,7 @@ public class LinuxMMLineUdpReceiver extends AbstractLineProtoUdpReceiver {
protected boolean runSerially() {
boolean ran = false;
int count;
while ((count = nf.recvmmsg(fd, msgVec, msgCount)) > 0) {
while ((count = nf.recvmmsgRaw(fd, msgVec, msgCount)) > 0) {
long p = msgVec;
for (int i = 0; i < count; i++) {
long buf = nf.getMMsgBuf(p);
......
......@@ -81,7 +81,7 @@ public final class UdpLineChannel implements LineChannel {
@Override
public void send(long ptr, int len) {
if (nf.sendTo(fd, ptr, len, sockaddr) != len) {
if (nf.sendToRaw(fd, ptr, len, sockaddr) != len) {
throw new LineSenderException("send error").errno(nf.errno());
}
}
......
......@@ -33,12 +33,13 @@ import io.questdb.griffin.CharacterStore;
import io.questdb.griffin.CharacterStoreEntry;
import io.questdb.log.Log;
import io.questdb.log.LogFactory;
import io.questdb.network.NetworkFacade;
import io.questdb.network.NoSpaceLeftInResponseBufferException;
import io.questdb.network.Socket;
import io.questdb.std.*;
import io.questdb.std.str.AbstractCharSink;
import io.questdb.std.str.CharSink;
import io.questdb.std.str.DirectByteCharSequence;
import org.jetbrains.annotations.NotNull;
public final class CleartextPasswordPgWireAuthenticator implements Authenticator {
public static final char STATUS_IDLE = 'I';
......@@ -57,12 +58,10 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
private final int circuitBreakerId;
private final DirectByteCharSequence dbcs = new DirectByteCharSequence();
private final boolean matcherOwned;
private final NetworkFacade nf;
private final OptionsListener optionsListener;
private final CircuitBreakerRegistry registry;
private final String serverVersion;
private final ResponseSink sink;
private int fd;
private UsernamePasswordMatcher matcher;
private long recvBufEnd;
private long recvBufReadPos;
......@@ -72,11 +71,11 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
private long sendBufReadPos;
private long sendBufStart;
private long sendBufWritePos;
private Socket socket;
private State state = State.EXPECT_INIT_MESSAGE;
private CharSequence username;
public CleartextPasswordPgWireAuthenticator(
NetworkFacade nf,
PGWireConfiguration configuration,
NetworkSqlExecutionCircuitBreaker circuitBreaker,
CircuitBreakerRegistry registry,
......@@ -86,7 +85,6 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
) {
this.matcher = matcher;
this.matcherOwned = matcherOwned;
this.nf = nf;
this.characterStore = new CharacterStore(
configuration.getCharacterStoreCapacity(),
configuration.getCharacterStorePoolCapacity()
......@@ -101,6 +99,7 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
@Override
public void clear() {
circuitBreaker.setSecret(-1);
circuitBreaker.resetMaxTimeToDefault();
circuitBreaker.unsetTimer();
}
......@@ -183,7 +182,7 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
break;
}
case AUTH_SUCCESS:
circuitBreaker.of(fd);
circuitBreaker.of(socket.getFd());
return Authenticator.OK;
case AUTH_FAILED:
return Authenticator.NEEDS_DISCONNECT;
......@@ -197,15 +196,11 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
}
@Override
public void init(int fd, long recvBuffer, long recvBufferLimit, long sendBuffer, long sendBufferLimit) {
if (fd == -1) {
this.circuitBreaker.setSecret(-1);
} else {
this.circuitBreaker.setSecret(registry.getNewSecret());
}
public void init(@NotNull Socket socket, long recvBuffer, long recvBufferLimit, long sendBuffer, long sendBufferLimit) {
this.circuitBreaker.setSecret(registry.getNewSecret());
this.state = State.EXPECT_INIT_MESSAGE;
this.username = null;
this.fd = fd;
this.socket = socket;
this.recvBufStart = recvBuffer;
this.recvBufReadPos = recvBuffer;
this.recvBufWritePos = recvBuffer;
......@@ -319,14 +314,14 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
// To issue a cancel request, the frontend opens a new connection to the server and sends a CancelRequest message, rather than the StartupMessage message
// that would ordinarily be sent across a new connection. The server will process this request and then close the connection.
// For security reasons, no direct reply is made to the cancel request message.
int pid = getIntUnsafe(recvBufReadPos);//thread id really
int pid = getIntUnsafe(recvBufReadPos); // thread id really
recvBufReadPos += Integer.BYTES;
int secret = getIntUnsafe(recvBufReadPos);
recvBufReadPos += Integer.BYTES;
LOG.info().$("cancel request [pid=").$(pid).I$();
try {
registry.cancel(pid, secret);
} catch (CairoException e) {//error message should not be sent to client
} catch (CairoException e) { // error message should not be sent to client
LOG.error().$(e.getMessage()).$();
}
}
......@@ -446,7 +441,7 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
}
private int readFromSocket() {
int bytesRead = nf.recv(fd, recvBufWritePos, (int) (recvBufEnd - recvBufWritePos));
int bytesRead = socket.recv(recvBufWritePos, (int) (recvBufEnd - recvBufWritePos));
if (bytesRead < 0) {
return Authenticator.NEEDS_DISCONNECT;
}
......@@ -456,7 +451,7 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
private int writeToSocketAndAdvance(State nextState) {
int toWrite = (int) (sendBufWritePos - sendBufReadPos);
int n = nf.send(fd, sendBufReadPos, toWrite);
int n = socket.send(sendBufReadPos, toWrite);
if (n < 0) {
return Authenticator.NEEDS_DISCONNECT;
}
......@@ -466,7 +461,7 @@ public final class CleartextPasswordPgWireAuthenticator implements Authenticator
state = nextState;
return Authenticator.OK;
}
// we could try to call nf.send() again as there could be space in the socket buffer now
// we could try to call socket.send() again as there could be space in the socket buffer now
// but: auth messages are small and we assume that the socket buffer is large enough to accommodate them in one go
// thus this return should be rare and we will just wait for the next select() call
return Authenticator.NEEDS_WRITE;
......
......@@ -27,7 +27,6 @@ package io.questdb.cutlass.pgwire;
import io.questdb.ServerMain;
import io.questdb.cairo.sql.NetworkSqlExecutionCircuitBreaker;
import io.questdb.cutlass.auth.Authenticator;
import io.questdb.network.NetworkFacade;
import io.questdb.std.str.DirectByteCharSink;
public final class DefaultPgWireAuthenticatorFactory implements PgWireAuthenticatorFactory {
......@@ -35,7 +34,6 @@ public final class DefaultPgWireAuthenticatorFactory implements PgWireAuthentica
@Override
public Authenticator getPgWireAuthenticator(
NetworkFacade nf,
PGWireConfiguration configuration,
NetworkSqlExecutionCircuitBreaker circuitBreaker,
CircuitBreakerRegistry registry,
......@@ -57,7 +55,6 @@ public final class DefaultPgWireAuthenticatorFactory implements PgWireAuthentica
);
return new CleartextPasswordPgWireAuthenticator(
nf,
configuration,
circuitBreaker,
registry,
......
......@@ -80,7 +80,6 @@ public class PGJobContext implements Closeable {
int operation
) throws PeerIsSlowToWriteException, PeerIsSlowToReadException, PeerDisconnectedException, QueryPausedException, BadProtocolException {
context.handleClientOperation(
engine,
typesAndSelectCache,
typesAndSelectPool,
typesAndUpdateCache,
......
......@@ -64,10 +64,7 @@ public class PGWireServer implements Closeable {
PGConnectionContextFactory contextFactory,
CircuitBreakerRegistry registry
) {
this.dispatcher = IODispatchers.create(
configuration.getDispatcherConfiguration(),
contextFactory
);
this.dispatcher = IODispatchers.create(configuration.getDispatcherConfiguration(), contextFactory);
this.metrics = engine.getMetrics();
this.workerPool = workerPool;
this.registry = registry;
......@@ -163,20 +160,30 @@ public class PGWireServer implements Closeable {
CircuitBreakerRegistry registry,
ObjectFactory<SqlExecutionContextImpl> executionContextObjectFactory
) {
super(() -> {
NetworkSqlExecutionCircuitBreaker circuitBreaker = new NetworkSqlExecutionCircuitBreaker(configuration.getCircuitBreakerConfiguration(), MemoryTag.NATIVE_CB5);
PGConnectionContext pgConnectionContext = new PGConnectionContext(
engine,
configuration,
executionContextObjectFactory.newInstance(),
circuitBreaker
);
FactoryProvider factoryProvider = configuration.getFactoryProvider();
NetworkFacade nf = configuration.getNetworkFacade();
Authenticator authenticator = factoryProvider.getPgWireAuthenticatorFactory().getPgWireAuthenticator(nf, configuration, circuitBreaker, registry, pgConnectionContext);
pgConnectionContext.setAuthenticator(authenticator);
return pgConnectionContext;
}, configuration.getConnectionPoolInitialCapacity());
super(
() -> {
NetworkSqlExecutionCircuitBreaker circuitBreaker = new NetworkSqlExecutionCircuitBreaker(
configuration.getCircuitBreakerConfiguration(),
MemoryTag.NATIVE_CB5
);
PGConnectionContext pgConnectionContext = new PGConnectionContext(
engine,
configuration,
executionContextObjectFactory.newInstance(),
circuitBreaker
);
FactoryProvider factoryProvider = configuration.getFactoryProvider();
Authenticator authenticator = factoryProvider.getPgWireAuthenticatorFactory().getPgWireAuthenticator(
configuration,
circuitBreaker,
registry,
pgConnectionContext
);
pgConnectionContext.setAuthenticator(authenticator);
return pgConnectionContext;
},
configuration.getConnectionPoolInitialCapacity()
);
}
}
}
......@@ -26,10 +26,12 @@ package io.questdb.cutlass.pgwire;
import io.questdb.cairo.sql.NetworkSqlExecutionCircuitBreaker;
import io.questdb.cutlass.auth.Authenticator;
import io.questdb.network.NetworkFacade;
public interface PgWireAuthenticatorFactory {
Authenticator getPgWireAuthenticator(NetworkFacade nf, PGWireConfiguration configuration,
NetworkSqlExecutionCircuitBreaker circuitBreaker, CircuitBreakerRegistry registry,
OptionsListener optionsListener);
Authenticator getPgWireAuthenticator(
PGWireConfiguration configuration,
NetworkSqlExecutionCircuitBreaker circuitBreaker,
CircuitBreakerRegistry registry,
OptionsListener optionsListener
);
}
......@@ -2,7 +2,6 @@ package io.questdb.cutlass.pgwire;
import io.questdb.cairo.sql.NetworkSqlExecutionCircuitBreaker;
import io.questdb.cutlass.auth.Authenticator;
import io.questdb.network.NetworkFacade;
public class UsernamePasswordPgWireAuthenticatorFactory implements PgWireAuthenticatorFactory {
......@@ -14,12 +13,11 @@ public class UsernamePasswordPgWireAuthenticatorFactory implements PgWireAuthent
@Override
public Authenticator getPgWireAuthenticator(
NetworkFacade nf,
PGWireConfiguration configuration,
NetworkSqlExecutionCircuitBreaker circuitBreaker,
CircuitBreakerRegistry registry,
OptionsListener optionsListener
) {
return new CleartextPasswordPgWireAuthenticator(nf, configuration, circuitBreaker, registry, optionsListener, matcher, false);
return new CleartextPasswordPgWireAuthenticator(configuration, circuitBreaker, registry, optionsListener, matcher, false);
}
}
......@@ -252,7 +252,7 @@ public class LogAlertSocket implements Closeable {
long p = outBufferPtr;
boolean sendFail = false;
while (remaining > 0) {
int n = nf.send(socketFd, p, remaining);
int n = nf.sendRaw(socketFd, p, remaining);
if (n > 0) {
remaining -= n;
p += n;
......@@ -269,7 +269,7 @@ public class LogAlertSocket implements Closeable {
if (!sendFail) {
// receive ack
p = inBufferPtr;
final int n = nf.recv(socketFd, p, inBufferSize);
final int n = nf.recvRaw(socketFd, p, inBufferSize);
if (n > 0) {
logResponse(n);
break;
......@@ -287,7 +287,8 @@ public class LogAlertSocket implements Closeable {
$alertHost(
alertHostIdx,
log.info().$("Failing over from")
).$(" to"));
).$(" to")
);
if (alertHostIdx == this.alertHostIdx) {
logFailOver.$(" with a delay of ")
.$(reconnectDelay / 1000000)
......@@ -393,7 +394,8 @@ public class LogAlertSocket implements Closeable {
throw new LogError(String.format(
"Unexpected ':' found at position %d: %s",
i,
alertTargets));
alertTargets
));
}
portIdx = i;
break;
......
......@@ -24,6 +24,7 @@
package io.questdb.network;
import io.questdb.cairo.CairoException;
import io.questdb.log.Log;
import io.questdb.log.LogFactory;
import io.questdb.mp.*;
......@@ -32,11 +33,20 @@ import io.questdb.std.datetime.millitime.MillisecondClock;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Base class for all I/O dispatchers.
* <p>
* Important invariant:
* dispatcher should never process a fd concurrently with I/O context. Instead, each of them does whatever
* it has to do with a fd and sends a message over an in-memory queue to tell the other party that it's
* free to proceed.
*/
public abstract class AbstractIODispatcher<C extends IOContext<C>> extends SynchronizedJob implements IODispatcher<C>, EagerThreadSetup {
protected static final int DISCONNECT_SRC_IDLE = 1;
protected static final int DISCONNECT_SRC_PEER_DISCONNECT = 3;
protected static final int DISCONNECT_SRC_QUEUE = 0;
protected static final int DISCONNECT_SRC_SHUTDOWN = 2;
protected static final int DISCONNECT_SRC_TLS_ERROR = 4;
protected static final int OPM_CREATE_TIMESTAMP = 0;
protected static final int OPM_FD = 1;
protected static final int OPM_HEARTBEAT_TIMESTAMP = 3;
......@@ -87,22 +97,22 @@ public abstract class AbstractIODispatcher<C extends IOContext<C>> extends Synch
this.nf = configuration.getNetworkFacade();
this.testConnectionBufSize = configuration.getTestConnectionBufferSize();
this.testConnectionBuf = Unsafe.malloc(this.testConnectionBufSize, MemoryTag.NATIVE_DEFAULT);
this.testConnectionBuf = Unsafe.malloc(testConnectionBufSize, MemoryTag.NATIVE_DEFAULT);
this.interestQueue = new RingQueue<>(IOEvent::new, configuration.getInterestQueueCapacity());
this.interestPubSeq = new MPSequence(interestQueue.getCycle());
this.interestSubSeq = new SCSequence();
this.interestPubSeq.then(this.interestSubSeq).then(this.interestPubSeq);
this.interestPubSeq.then(interestSubSeq).then(interestPubSeq);
this.ioEventQueue = new RingQueue<>(IOEvent::new, configuration.getIOQueueCapacity());
this.ioEventPubSeq = new SPSequence(configuration.getIOQueueCapacity());
this.ioEventSubSeq = new MCSequence(configuration.getIOQueueCapacity());
this.ioEventPubSeq.then(this.ioEventSubSeq).then(this.ioEventPubSeq);
this.ioEventPubSeq.then(ioEventSubSeq).then(ioEventPubSeq);
this.disconnectQueue = new RingQueue<>(IOEvent::new, configuration.getIOQueueCapacity());
this.disconnectPubSeq = new MPSequence(disconnectQueue.getCycle());
this.disconnectSubSeq = new SCSequence();
this.disconnectPubSeq.then(this.disconnectSubSeq).then(this.disconnectPubSeq);
this.disconnectPubSeq.then(disconnectSubSeq).then(disconnectPubSeq);
this.clock = configuration.getClock();
this.activeConnectionLimit = configuration.getLimit();
......@@ -212,13 +222,21 @@ public abstract class AbstractIODispatcher<C extends IOContext<C>> extends Synch
private void addPending(int fd, long timestamp) {
// append pending connection
// all rows below watermark will be registered with epoll (or similar)
final C context = ioContextFactory.newInstance(fd, this);
try {
context.init();
} catch (CairoException e) {
LOG.error().$("could not initialize connection context [fd=").$(fd).$(", e=").$(e.getFlyweightMessage()).I$();
ioContextFactory.done(context);
return;
}
int r = pending.addRow();
LOG.debug().$("pending [row=").$(r).$(", fd=").$(fd).$(']').$();
LOG.debug().$("pending [row=").$(r).$(", fd=").$(fd).I$();
pending.set(r, OPM_CREATE_TIMESTAMP, timestamp);
pending.set(r, OPM_HEARTBEAT_TIMESTAMP, timestamp);
pending.set(r, OPM_FD, fd);
pending.set(r, OPM_OPERATION, -1);
pending.set(r, ioContextFactory.newInstance(fd, this));
pending.set(r, context);
pendingAdded(r);
}
......@@ -242,7 +260,8 @@ public abstract class AbstractIODispatcher<C extends IOContext<C>> extends Synch
throw NetworkError.instance(nf.errno()).couldNotBindSocket(
configuration.getDispatcherLogName(),
configuration.getBindIPv4Address(),
this.port);
this.port
);
}
LOG.advisory().$("listening on ").$ip(configuration.getBindIPv4Address()).$(':').$(configuration.getBindPort())
.$(" [fd=").$(serverFd)
......@@ -266,6 +285,15 @@ public abstract class AbstractIODispatcher<C extends IOContext<C>> extends Synch
doDisconnect(context, DISCONNECT_SRC_QUEUE);
}
protected static int tlsIOFlags(int requestedOp, boolean readyForRead, boolean readyForWrite) {
return (requestedOp == IOOperation.READ && readyForWrite ? Socket.WRITE_FLAG : 0)
| (requestedOp == IOOperation.WRITE && readyForRead ? Socket.READ_FLAG : 0);
}
protected static int tlsIOFlags(boolean readyForRead, boolean readyForWrite) {
return (readyForWrite ? Socket.WRITE_FLAG : 0) | (readyForRead ? Socket.READ_FLAG : 0);
}
protected void accept(long timestamp) {
int tlConCount = this.connectionCount.get();
while (tlConCount < activeConnectionLimit) {
......@@ -277,13 +305,13 @@ public abstract class AbstractIODispatcher<C extends IOContext<C>> extends Synch
if (fd < 0) {
if (nf.errno() != Net.EWOULDBLOCK) {
LOG.error().$("could not accept [ret=").$(fd).$(", errno=").$(nf.errno()).$(']').$();
LOG.error().$("could not accept [ret=").$(fd).$(", errno=").$(nf.errno()).I$();
}
break;
}
if (nf.configureNonBlocking(fd) < 0) {
LOG.error().$("could not configure non-blocking [fd=").$(fd).$(", errno=").$(nf.errno()).$(']').$();
LOG.error().$("could not configure non-blocking [fd=").$(fd).$(", errno=").$(nf.errno()).I$();
nf.close(fd, LOG);
break;
}
......@@ -291,7 +319,7 @@ public abstract class AbstractIODispatcher<C extends IOContext<C>> extends Synch
if (nf.setTcpNoDelay(fd, true) < 0) {
// Randomly on OS X, if a client connects and the peer TCP socket has SO_LINGER set to false, then setting the TCP_NODELAY
// option fails!
LOG.info().$("could not turn off Nagle's algorithm [fd=").$(fd).$(", errno=").$(nf.errno()).$(']').$();
LOG.info().$("could not turn off Nagle's algorithm [fd=").$(fd).$(", errno=").$(nf.errno()).I$();
}
if (peerNoLinger) {
......@@ -306,7 +334,7 @@ public abstract class AbstractIODispatcher<C extends IOContext<C>> extends Synch
nf.setRcvBuf(fd, rcvBufSize);
}
LOG.info().$("connected [ip=").$ip(nf.getPeerIP(fd)).$(", fd=").$(fd).$(']').$();
LOG.info().$("connected [ip=").$ip(nf.getPeerIP(fd)).$(", fd=").$(fd).I$();
tlConCount = connectionCount.incrementAndGet();
addPending(fd, timestamp);
}
......@@ -332,7 +360,6 @@ public abstract class AbstractIODispatcher<C extends IOContext<C>> extends Synch
.$(", fd=").$(fd)
.$(", src=").$(DISCONNECT_SOURCES[src])
.I$();
nf.close(fd, LOG);
if (closed) {
Misc.free(context);
} else {
......@@ -382,6 +409,6 @@ public abstract class AbstractIODispatcher<C extends IOContext<C>> extends Synch
protected abstract void unregisterListenerFd();
static {
DISCONNECT_SOURCES = new String[]{"queue", "idle", "shutdown", "peer"};
DISCONNECT_SOURCES = new String[]{"queue", "idle", "shutdown", "peer", "tls_error"};
}
}
......@@ -55,11 +55,6 @@ public class DefaultIODispatcherConfiguration implements IODispatcherConfigurati
return -1L;
}
@Override
public int getInitialBias() {
return BIAS_READ;
}
@Override
public KqueueFacade getKqueueFacade() {
return KqueueFacadeImpl.INSTANCE;
......
......@@ -24,25 +24,34 @@
package io.questdb.network;
import io.questdb.log.Log;
import io.questdb.std.Mutable;
import io.questdb.std.QuietCloseable;
import org.jetbrains.annotations.NotNull;
public abstract class IOContext<T extends IOContext<T>> implements Mutable, QuietCloseable {
protected final Socket socket;
protected IODispatcher<T> dispatcher;
protected int fd = -1;
protected long heartbeatId = -1;
protected IOContext(SocketFactory socketFactory, NetworkFacade nf, Log log) {
this.socket = socketFactory.newInstance(nf, log);
}
@Override
public void clear() {
heartbeatId = -1;
fd = -1;
dispatcher = null;
_clear();
}
public void clearSuspendEvent() {
// no-op
}
@Override
public void close() {
_clear();
}
public long getAndResetHeartbeatId() {
long id = heartbeatId;
heartbeatId = -1;
......@@ -54,20 +63,31 @@ public abstract class IOContext<T extends IOContext<T>> implements Mutable, Quie
}
public int getFd() {
return fd;
return socket != null ? socket.getFd() : -1;
}
public Socket getSocket() {
return socket;
}
public SuspendEvent getSuspendEvent() {
return null;
}
/**
* @throws io.questdb.cairo.CairoException if initialization fails
*/
public void init() {
// no-op
}
public boolean invalid() {
return fd == -1;
return socket.getFd() == -1;
}
@SuppressWarnings("unchecked")
public T of(int fd, IODispatcher<T> dispatcher) {
this.fd = fd;
public T of(int fd, @NotNull IODispatcher<T> dispatcher) {
socket.of(fd);
this.dispatcher = dispatcher;
return (T) this;
}
......@@ -75,4 +95,11 @@ public abstract class IOContext<T extends IOContext<T>> implements Mutable, Quie
public void setHeartbeatId(long heartbeatId) {
this.heartbeatId = heartbeatId;
}
private void _clear() {
heartbeatId = -1;
socket.close();
dispatcher = null;
clearSuspendEvent();
}
}
......@@ -29,6 +29,7 @@ import io.questdb.std.Misc;
import io.questdb.std.ObjectFactory;
import io.questdb.std.ThreadLocal;
import io.questdb.std.WeakMutableObjectPool;
import org.jetbrains.annotations.NotNull;
import java.io.Closeable;
......@@ -52,17 +53,16 @@ public class IOContextFactoryImpl<C extends IOContext<C>> implements IOContextFa
if (closed) {
Misc.free(context);
} else {
context.of(-1, null);
contextPool.get().push(context);
}
}
public void freeThreadLocal() {
// helper call, it will free only thread-local instance and not others
Misc.free(this.contextPool);
Misc.free(contextPool);
}
public C newInstance(int fd, IODispatcher<C> dispatcher) {
public C newInstance(int fd, @NotNull IODispatcher<C> dispatcher) {
return contextPool.get().pop().of(fd, dispatcher);
}
......
......@@ -28,7 +28,7 @@ import io.questdb.mp.Job;
import java.io.Closeable;
public interface IODispatcher<C extends IOContext> extends Closeable, Job {
public interface IODispatcher<C extends IOContext<C>> extends Closeable, Job {
int DISCONNECT_REASON_KEEPALIVE_OFF = 1;
int DISCONNECT_REASON_KEEPALIVE_OFF_RECV = 4;
int DISCONNECT_REASON_KICKED_OUT_AT_EXTRA_BYTES = 13;
......
......@@ -58,7 +58,9 @@ public interface IODispatcherConfiguration {
return Numbers.ceilPow2(getLimit());
}
int getInitialBias();
default int getInitialBias() {
return BIAS_READ;
}
default int getInterestQueueCapacity() {
return Numbers.ceilPow2(getLimit());
......
......@@ -78,18 +78,29 @@ public class IODispatcherLinux<C extends IOContext<C>> extends AbstractIODispatc
private void enqueuePending(int watermark) {
for (int i = watermark, sz = pending.size(), offset = 0; i < sz; i++, offset += EpollAccessor.SIZEOF_EVENT) {
final C context = pending.get(i);
final long id = pending.get(i, OPM_ID);
final int fd = (int) pending.get(i, OPM_FD);
int operation = initialBias == IODispatcherConfiguration.BIAS_READ ? IOOperation.READ : IOOperation.WRITE;
final int operation = initialBias == IODispatcherConfiguration.BIAS_READ ? IOOperation.READ : IOOperation.WRITE;
pending.set(i, OPM_OPERATION, operation);
int event = operation == IOOperation.READ ? EpollAccessor.EPOLLIN : EpollAccessor.EPOLLOUT;
if (epoll.control(fd, id, EpollAccessor.EPOLL_CTL_ADD, event) < 0) {
if (epoll.control(fd, id, EpollAccessor.EPOLL_CTL_ADD, epollOp(operation, context)) < 0) {
LOG.critical().$("internal error: epoll_ctl failure [id=").$(id)
.$(", err=").$(nf.errno()).I$();
}
}
}
private int epollOp(int operation, C context) {
int op = operation == IOOperation.READ ? EpollAccessor.EPOLLIN : EpollAccessor.EPOLLOUT;
if (context.getSocket().wantsTlsRead()) {
op |= EpollAccessor.EPOLLIN;
}
if (context.getSocket().wantsTlsWrite()) {
op |= EpollAccessor.EPOLLOUT;
}
return op;
}
private boolean handleSocketOperation(long id) {
// find row in pending for two reasons:
// 1. find payload
......@@ -108,15 +119,35 @@ public class IODispatcherLinux<C extends IOContext<C>> extends AbstractIODispatc
doDisconnect(context, id, DISCONNECT_SRC_PEER_DISCONNECT);
pending.deleteRow(row);
return true;
} else {
// the connection is alive, so we need to re-arm epoll to be able to detect broken connection
rearmEpoll(context, id, IOOperation.READ);
}
} else {
publishOperation(
// Check EPOLLOUT flag and treat all other events, including EPOLLIN and EPOLLHUP, as a read.
(epoll.getEvent() & EpollAccessor.EPOLLOUT) > 0 ? IOOperation.WRITE : IOOperation.READ,
context
);
pending.deleteRow(row);
return true;
final int requestedOp = (int) pending.get(row, OPM_OPERATION);
// We check EPOLLOUT flag and treat all other events, including EPOLLIN and EPOLLHUP, as a read.
final boolean readyForWrite = (epoll.getEvent() & EpollAccessor.EPOLLOUT) > 0;
final boolean readyForRead = !readyForWrite || (epoll.getEvent() & EpollAccessor.EPOLLIN) > 0;
if ((requestedOp == IOOperation.WRITE && readyForWrite) || (requestedOp == IOOperation.READ && readyForRead)) {
// If the socket is also ready for another operation type, do it.
if (context.getSocket().tlsIO(tlsIOFlags(requestedOp, readyForRead, readyForWrite)) < 0) {
doDisconnect(context, id, DISCONNECT_SRC_TLS_ERROR);
pending.deleteRow(row);
return true;
}
publishOperation(requestedOp, context);
pending.deleteRow(row);
return true;
}
// It's something different from the requested operation.
if (context.getSocket().tlsIO(tlsIOFlags(readyForRead, readyForWrite)) < 0) {
doDisconnect(context, id, DISCONNECT_SRC_TLS_ERROR);
pending.deleteRow(row);
return true;
}
rearmEpoll(context, id, requestedOp);
}
return false;
}
......@@ -288,8 +319,7 @@ public class IODispatcherLinux<C extends IOContext<C>> extends AbstractIODispatc
// we re-arm epoll globally, in that even when we disconnect
// because we have to remove FD from epoll
final int epollOp = operation == IOOperation.READ ? EpollAccessor.EPOLLIN : EpollAccessor.EPOLLOUT;
if (epoll.control(fd, opId, epollCmd, epollOp) < 0) {
if (epoll.control(fd, opId, epollCmd, epollOp(operation, context)) < 0) {
LOG.critical().$("internal error: epoll_ctl modify operation failure [id=").$(opId)
.$(", err=").$(nf.errno()).I$();
}
......@@ -322,19 +352,16 @@ public class IODispatcherLinux<C extends IOContext<C>> extends AbstractIODispatc
pendingEvents.zapTop(count);
}
private void resumeOperation(C context, long id, int operation) {
// to resume a socket operation, we simply re-arm epoll
if (
epoll.control(
context.getFd(),
id,
EpollAccessor.EPOLL_CTL_MOD,
operation == IOOperation.READ ? EpollAccessor.EPOLLIN : EpollAccessor.EPOLLOUT
) < 0
) {
LOG.critical().$("internal error: epoll_ctl operation mod failure [id=").$(id)
private void rearmEpoll(C context, long id, int operation) {
if (epoll.control(context.getFd(), id, EpollAccessor.EPOLL_CTL_MOD, epollOp(operation, context)) < 0) {
LOG.critical().$("internal error: epoll_ctl modify operation failure [id=").$(id)
.$(", err=").$(nf.errno()).I$();
}
}
private void resumeOperation(C context, long id, int operation) {
// to resume a socket operation, we simply re-arm epoll
rearmEpoll(context, id, operation);
context.clearSuspendEvent();
}
......
......@@ -24,6 +24,7 @@
package io.questdb.network;
import io.questdb.std.IntHashSet;
import io.questdb.std.LongMatrix;
public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatcher<C> {
......@@ -31,9 +32,12 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
private static final int EVM_ID = 0;
private static final int EVM_OPERATION_ID = 2;
protected final LongMatrix pendingEvents = new LongMatrix(3);
private final IntHashSet alreadyHandledFds = new IntHashSet();
private final int capacity;
private final KeventWriter keventWriter = new KeventWriter();
private final Kqueue kqueue;
// the final ids are shifted by 1 bit which is reserved to distinguish socket operations (0) and suspend events (1)
// the final ids are shifted by 1 bit which is reserved to distinguish
// socket operations (0) and suspend events (1)
private long idSeq = 1;
public IODispatcherOsx(
......@@ -50,7 +54,7 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
@Override
public void close() {
super.close();
this.kqueue.close();
kqueue.close();
LOG.info().$("closed").$();
}
......@@ -67,9 +71,7 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
if (eventRow < 0) {
LOG.critical().$("internal error: suspend event not found [id=").$(id).I$();
} else {
kqueue.setWriteOffset(0);
kqueue.removeReadFD(suspendEvent.getFd());
registerWithKQueue(1);
keventWriter.prepare().removeReadFD(suspendEvent.getFd()).done();
pendingEvents.deleteRow(eventRow);
}
}
......@@ -77,28 +79,20 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
}
private void enqueuePending(int watermark) {
int index = 0;
for (int i = watermark, sz = pending.size(), offset = 0; i < sz; i++, offset += KqueueAccessor.SIZEOF_KEVENT) {
kqueue.setWriteOffset(offset);
final int fd = (int) pending.get(i, OPM_FD);
long id = pending.get(i, OPM_ID);
keventWriter.prepare();
for (int i = watermark, sz = pending.size(); i < sz; i++) {
final C context = pending.get(i);
final long id = pending.get(i, OPM_ID);
final int operation = initialBias == IODispatcherConfiguration.BIAS_READ ? IOOperation.READ : IOOperation.WRITE;
if (operation == IOOperation.READ) {
kqueue.readFD(fd, id);
} else {
kqueue.writeFD(fd, id);
}
pending.set(i, OPM_OPERATION, operation);
if (++index > capacity - 1) {
registerWithKQueue(index);
index = 0;
offset = 0;
if (operation == IOOperation.READ || context.getSocket().wantsTlsRead()) {
keventWriter.readFD(context.getFd(), id);
}
if (operation == IOOperation.WRITE || context.getSocket().wantsTlsWrite()) {
keventWriter.writeFD(context.getFd(), id);
}
}
if (index > 0) {
registerWithKQueue(index);
}
keventWriter.done();
}
private boolean handleSocketOperation(long id) {
......@@ -119,14 +113,38 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
doDisconnect(context, id, DISCONNECT_SRC_PEER_DISCONNECT);
pending.deleteRow(row);
return true;
} else {
// the connection is alive, so we need to re-arm kqueue to be able to detect broken connection
rearmKqueue(context, id, IOOperation.READ);
}
} else {
publishOperation(
kqueue.getFilter() == KqueueAccessor.EVFILT_READ ? IOOperation.READ : IOOperation.WRITE,
context
);
pending.deleteRow(row);
return true;
final int requestedOp = (int) pending.get(row, OPM_OPERATION);
final boolean readyForWrite = kqueue.getFilter() == KqueueAccessor.EVFILT_WRITE;
final boolean readyForRead = kqueue.getFilter() == KqueueAccessor.EVFILT_READ;
if ((requestedOp == IOOperation.WRITE && readyForWrite) || (requestedOp == IOOperation.READ && readyForRead)) {
// disarm extra filter in case it was previously set and haven't fired yet
keventWriter.prepare().tolerateErrors();
if (requestedOp == IOOperation.READ && context.getSocket().wantsTlsWrite()) {
keventWriter.removeWriteFD(context.getFd());
}
if (requestedOp == IOOperation.WRITE && context.getSocket().wantsTlsRead()) {
keventWriter.removeReadFD(context.getFd());
}
keventWriter.done();
// publish the operation and we're done
publishOperation(requestedOp, context);
pending.deleteRow(row);
return true;
} else {
// that's not the requested operation, but something wanted by the socket
if (context.getSocket().tlsIO(tlsIOFlags(readyForRead, readyForWrite)) < 0) {
doDisconnect(context, id, DISCONNECT_SRC_TLS_ERROR);
pending.deleteRow(row);
return true;
}
rearmKqueue(context, id, requestedOp);
}
}
return false;
}
......@@ -155,14 +173,8 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
.$(", opId=").$(opId)
.$(", eventId=").$(eventId).I$();
rearmKqueue(context, opId, operation);
context.clearSuspendEvent();
kqueue.setWriteOffset(0);
if (operation == IOOperation.READ) {
kqueue.readFD(context.getFd(), opId);
} else {
kqueue.writeFD(context.getFd(), opId);
}
registerWithKQueue(1);
pendingEvents.deleteRow(eventsRow);
}
......@@ -180,17 +192,18 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
for (int i = 0; i < watermark && pending.get(i, OPM_HEARTBEAT_TIMESTAMP) < timestamp; i++, count++) {
final C context = pending.get(i);
// De-register pending operation from epoll. We'll register it later when we get a heartbeat pong.
// De-register pending operation from kqueue. We'll register it later when we get a heartbeat pong.
int fd = context.getFd();
final long opId = pending.get(i, OPM_ID);
kqueue.setWriteOffset(0);
long op = context.getSuspendEvent() != null ? IOOperation.READ : pending.get(i, OPM_OPERATION);
if (op == IOOperation.READ) {
kqueue.removeReadFD(fd);
} else {
kqueue.removeWriteFD(fd);
keventWriter.prepare().tolerateErrors();
if (op == IOOperation.READ || context.getSocket().wantsTlsRead()) {
keventWriter.removeReadFD(fd);
}
if (op == IOOperation.WRITE || context.getSocket().wantsTlsWrite()) {
keventWriter.removeWriteFD(fd);
}
if (kqueue.register(1) != 0) {
if (keventWriter.done() != 0) {
LOG.critical().$("internal error: kqueue remove fd failure [fd=").$(fd)
.$(", err=").$(nf.errno()).I$();
} else {
......@@ -212,15 +225,13 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
final SuspendEvent suspendEvent = context.getSuspendEvent();
if (suspendEvent != null) {
// Also, de-register suspend event from epoll.
// Also, de-register suspend event from kqueue.
int eventRow = pendingEvents.binarySearch(opId, EVM_OPERATION_ID);
if (eventRow < 0) {
LOG.critical().$("internal error: suspend event not found on heartbeat [id=").$(opId).I$();
} else {
final long eventId = pendingEvents.get(eventRow, EVM_ID);
kqueue.setWriteOffset(0);
kqueue.readFD(context.getFd(), eventId);
registerWithKQueue(1);
keventWriter.prepare().readFD(suspendEvent.getFd(), eventId).done();
pendingEvents.deleteRow(eventRow);
}
}
......@@ -240,8 +251,7 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
private boolean processRegistrations(long timestamp) {
long cursor;
boolean useful = false;
int count = 0;
int offset = 0;
keventWriter.prepare();
while ((cursor = interestSubSeq.next()) > -1) {
final IOEvent<C> event = interestQueue.get(cursor);
final C context = event.context;
......@@ -309,38 +319,23 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
pendingEvents.set(eventRow, EVM_OPERATION_ID, opId);
pendingEvents.set(eventRow, EVM_DEADLINE, suspendEvent.getDeadline());
kqueue.setWriteOffset(offset);
kqueue.readFD(suspendEvent.getFd(), eventId);
offset += KqueueAccessor.SIZEOF_KEVENT;
if (++count > capacity - 1) {
registerWithKQueue(count);
count = offset = 0;
}
keventWriter.readFD(suspendEvent.getFd(), eventId);
}
kqueue.setWriteOffset(offset);
if (operation == IOOperation.READ) {
kqueue.readFD(fd, opId);
} else {
kqueue.writeFD(fd, opId);
if (operation == IOOperation.READ || context.getSocket().wantsTlsRead()) {
keventWriter.readFD(fd, opId);
}
offset += KqueueAccessor.SIZEOF_KEVENT;
if (++count > capacity - 1) {
registerWithKQueue(count);
count = offset = 0;
if (operation == IOOperation.WRITE || context.getSocket().wantsTlsWrite()) {
keventWriter.writeFD(fd, opId);
}
}
if (count > 0) {
registerWithKQueue(count);
}
keventWriter.done();
return useful;
}
private void processSuspendEventDeadlines(long timestamp) {
int index = 0;
int offset = 0;
int count = 0;
for (int i = 0, n = pendingEvents.size(); i < n && pendingEvents.get(i, EVM_DEADLINE) < timestamp; i++, count++) {
final long opId = pendingEvents.get(i, EVM_OPERATION_ID);
......@@ -355,39 +350,25 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
final int operation = (int) pending.get(pendingRow, OPM_OPERATION);
final SuspendEvent suspendEvent = context.getSuspendEvent();
assert suspendEvent != null;
kqueue.setWriteOffset(offset);
kqueue.removeReadFD(suspendEvent.getFd());
offset += KqueueAccessor.SIZEOF_KEVENT;
if (++index > capacity - 1) {
registerWithKQueue(index);
index = offset = 0;
}
keventWriter.prepare().removeReadFD(suspendEvent.getFd()).done();
// Next, close the event and resume the original operation.
// to resume a socket operation, we simply re-arm kqueue
rearmKqueue(context, opId, operation);
context.clearSuspendEvent();
kqueue.setWriteOffset(offset);
if (operation == IOOperation.READ) {
kqueue.readFD(context.getFd(), opId);
} else {
kqueue.writeFD(context.getFd(), opId);
}
offset += KqueueAccessor.SIZEOF_KEVENT;
if (++index > capacity - 1) {
registerWithKQueue(index);
index = offset = 0;
}
}
if (index > 0) {
registerWithKQueue(index);
}
pendingEvents.zapTop(count);
}
private void registerWithKQueue(int changeCount) {
if (kqueue.register(changeCount) != 0) {
throw NetworkError.instance(nf.errno()).put("could not register [changeCount=").put(changeCount).put(']');
private void rearmKqueue(C context, long id, int operation) {
keventWriter.prepare();
if (operation == IOOperation.READ || context.getSocket().wantsTlsRead()) {
keventWriter.readFD(context.getFd(), id);
}
if (operation == IOOperation.WRITE || context.getSocket().wantsTlsWrite()) {
keventWriter.writeFD(context.getFd(), id);
}
LOG.debug().$("kqueued [count=").$(changeCount).$(']').$();
keventWriter.done();
}
@Override
......@@ -397,7 +378,7 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
@Override
protected void registerListenerFd() {
if (this.kqueue.listen(serverFd) != 0) {
if (kqueue.listen(serverFd) != 0) {
throw NetworkError.instance(nf.errno(), "could not kqueue.listen()");
}
}
......@@ -408,6 +389,7 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
final long timestamp = clock.getTicks();
processDisconnects(timestamp);
alreadyHandledFds.clear();
final int n = kqueue.poll();
int watermark = pending.size();
int offset = 0;
......@@ -425,10 +407,16 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
useful = true;
continue;
}
if (!alreadyHandledFds.add(fd)) {
// we already handled this fd (socket), but for another event;
// ignore this one as we already removed/re-armed kqueue filter
continue;
}
if (isEventId(id)) {
handleSuspendEvent(id);
continue;
}
// since we may register multiple times
if (handleSocketOperation(id)) {
useful = true;
watermark--;
......@@ -465,8 +453,92 @@ public class IODispatcherOsx<C extends IOContext<C>> extends AbstractIODispatche
@Override
protected void unregisterListenerFd() {
if (this.kqueue.removeListen(serverFd) != 0) {
if (kqueue.removeListen(serverFd) != 0) {
throw NetworkError.instance(nf.errno(), "could not kqueue.removeListen()");
}
}
private class KeventWriter {
private int index;
private int lastError; // 0 means no error
private int offset;
private boolean tolerateErrors;
public int done() {
if (index > 0) {
register(index);
}
index = 0;
offset = 0;
tolerateErrors = false;
int lastError = this.lastError;
this.lastError = 0;
return lastError;
}
public KeventWriter readFD(int fd, long id) {
kqueue.setWriteOffset(offset);
kqueue.readFD(fd, id);
offset += KqueueAccessor.SIZEOF_KEVENT;
if (++index > capacity - 1) {
register(index);
index = offset = 0;
}
return this;
}
public KeventWriter removeReadFD(int fd) {
kqueue.setWriteOffset(offset);
kqueue.removeReadFD(fd);
offset += KqueueAccessor.SIZEOF_KEVENT;
if (++index > capacity - 1) {
register(index);
index = offset = 0;
}
return this;
}
public KeventWriter removeWriteFD(int fd) {
kqueue.setWriteOffset(offset);
kqueue.removeWriteFD(fd);
offset += KqueueAccessor.SIZEOF_KEVENT;
if (++index > capacity - 1) {
register(index);
index = offset = 0;
}
return this;
}
public KeventWriter tolerateErrors() {
this.tolerateErrors = true;
return this;
}
public KeventWriter writeFD(int fd, long id) {
kqueue.setWriteOffset(offset);
kqueue.writeFD(fd, id);
offset += KqueueAccessor.SIZEOF_KEVENT;
if (++index > capacity - 1) {
register(index);
index = offset = 0;
}
return this;
}
private KeventWriter prepare() {
if (index > 0 || offset > 0) {
throw new IllegalStateException("missing done() call");
}
return this;
}
private void register(int changeCount) {
int res = kqueue.register(changeCount);
if (!tolerateErrors && res != 0) {
throw NetworkError.instance(nf.errno()).put("could not register [changeCount=").put(changeCount).put(']');
}
lastError = res != 0 ? res : lastError;
LOG.debug().$("kqueued [count=").$(changeCount).$(']').$();
}
}
}
......@@ -243,10 +243,11 @@ public class IODispatcherWindows<C extends IOContext<C>> extends AbstractIODispa
operation = IOOperation.READ;
}
if (operation == IOOperation.READ) {
if (operation == IOOperation.READ || context.getSocket().wantsTlsRead()) {
readFdSet.add(fd);
readFdCount++;
} else {
}
if (operation == IOOperation.WRITE || context.getSocket().wantsTlsWrite()) {
writeFdSet.add(fd);
writeFdCount++;
}
......@@ -261,24 +262,60 @@ public class IODispatcherWindows<C extends IOContext<C>> extends AbstractIODispa
n--;
watermark--;
} else {
i++; // just skip to the next operation
// the connection is alive, so we need to add it to poll to be able to detect broken connection
readFdSet.add(fd);
readFdCount++;
if (context.getSocket().wantsTlsWrite()) {
writeFdSet.add(fd);
writeFdCount++;
}
i++; // now skip to the next operation
}
continue;
}
// publish event and remove from pending
// we got a (potentially requested) event
useful = true;
if ((newOp & SelectAccessor.FD_READ) > 0) {
publishOperation(IOOperation.READ, context);
}
if ((newOp & SelectAccessor.FD_WRITE) > 0) {
publishOperation(IOOperation.WRITE, context);
}
final int requestedOp = (int) pending.get(i, OPM_OPERATION);
final boolean readyForWrite = (newOp & SelectAccessor.FD_WRITE) > 0;
final boolean readyForRead = (newOp & SelectAccessor.FD_READ) > 0;
pending.deleteRow(i);
n--;
watermark--;
if ((requestedOp == IOOperation.WRITE && readyForWrite) || (requestedOp == IOOperation.READ && readyForRead)) {
// If the socket is also ready for another operation type, do it.
if (context.getSocket().tlsIO(tlsIOFlags(requestedOp, readyForRead, readyForWrite)) < 0) {
doDisconnect(context, DISCONNECT_SRC_TLS_ERROR);
pending.deleteRow(i);
n--;
watermark--;
continue;
}
// publish event and remove from pending
publishOperation(requestedOp, context);
pending.deleteRow(i);
n--;
watermark--;
} else {
// It's something different from the requested operation.
if (context.getSocket().tlsIO(tlsIOFlags(readyForRead, readyForWrite)) < 0) {
doDisconnect(context, DISCONNECT_SRC_TLS_ERROR);
pending.deleteRow(i);
n--;
watermark--;
continue;
}
// Now we need to re-arm poll.
if (requestedOp == IOOperation.READ || context.getSocket().wantsTlsRead()) {
readFdSet.add(fd);
readFdCount++;
}
if (requestedOp == IOOperation.WRITE || context.getSocket().wantsTlsWrite()) {
writeFdSet.add(fd);
writeFdCount++;
}
i++; // now skip to the next operation
}
}
}
......
......@@ -85,18 +85,18 @@ public interface NetworkFacade {
int parseIPv4(CharSequence ipv4Address);
int peek(int fd, long buffer, int bufferLen);
int peekRaw(int fd, long buffer, int bufferLen);
int recv(int fd, long buffer, int bufferLen);
int recvRaw(int fd, long buffer, int bufferLen);
@SuppressWarnings("SpellCheckingInspection")
int recvmmsg(int fd, long msgVec, int msgCount);
int recvmmsgRaw(int fd, long msgVec, int msgCount);
int resolvePort(int fd);
int send(int fd, long buffer, int bufferLen);
int sendRaw(int fd, long buffer, int bufferLen);
int sendTo(int fd, long lo, int len, long socketAddress);
int sendToRaw(int fd, long lo, int len, long socketAddress);
int setMulticastInterface(int fd, CharSequence address);
......
......@@ -27,7 +27,6 @@ package io.questdb.network;
import io.questdb.log.Log;
import io.questdb.std.Files;
import io.questdb.std.Os;
import io.questdb.std.Unsafe;
import io.questdb.std.str.LPSZ;
public class NetworkFacadeImpl implements NetworkFacade {
......@@ -176,17 +175,17 @@ public class NetworkFacadeImpl implements NetworkFacade {
}
@Override
public int peek(int fd, long buffer, int bufferLen) {
public int peekRaw(int fd, long buffer, int bufferLen) {
return Net.peek(fd, buffer, bufferLen);
}
@Override
public int recv(int fd, long buffer, int bufferLen) {
public int recvRaw(int fd, long buffer, int bufferLen) {
return Net.recv(fd, buffer, bufferLen);
}
@Override
public int recvmmsg(int fd, long msgVec, int msgCount) {
public int recvmmsgRaw(int fd, long msgVec, int msgCount) {
return Net.recvmmsg(fd, msgVec, msgCount);
}
......@@ -196,12 +195,12 @@ public class NetworkFacadeImpl implements NetworkFacade {
}
@Override
public int send(int fd, long buffer, int bufferLen) {
public int sendRaw(int fd, long buffer, int bufferLen) {
return Net.send(fd, buffer, bufferLen);
}
@Override
public int sendTo(int fd, long ptr, int len, long socketAddress) {
public int sendToRaw(int fd, long ptr, int len, long socketAddress) {
return Net.sendTo(fd, ptr, len, socketAddress);
}
......@@ -280,30 +279,6 @@ public class NetworkFacadeImpl implements NetworkFacade {
}
final int nRead = Net.peek(fd, buffer, bufferSize);
if (nRead == 0) {
return false;
}
if (nRead < 0) {
return true;
}
// Read \r\n from the input stream and discard it since some HTTP clients
// send these chars as a keep alive in between requests.
int index = 0;
while (index < nRead) {
byte b = Unsafe.getUnsafe().getByte(buffer + index);
if (b != (byte) '\r' && b != (byte) '\n') {
break;
}
index++;
}
if (index > 0) {
Net.recv(fd, buffer, index);
}
return false;
return nRead < 0;
}
}
/*******************************************************************************
* ___ _ ____ ____
* / _ \ _ _ ___ ___| |_| _ \| __ )
* | | | | | | |/ _ \/ __| __| | | | _ \
* | |_| | |_| | __/\__ \ |_| |_| | |_) |
* \__\_\\__,_|\___||___/\__|____/|____/
*
* Copyright (c) 2014-2019 Appsicle
* Copyright (c) 2019-2023 QuestDB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
package io.questdb.network;
import io.questdb.log.Log;
public class PlainSocket implements Socket {
private final Log log;
private final NetworkFacade nf;
private int fd = -1;
public PlainSocket(NetworkFacade nf, Log log) {
this.nf = nf;
this.log = log;
}
@Override
public void close() {
if (fd != -1) {
nf.close(fd, log);
fd = -1;
}
}
@Override
public int getFd() {
return fd;
}
@Override
public boolean isTlsSessionStarted() {
return false;
}
@Override
public void of(int fd) {
assert this.fd == -1;
this.fd = fd;
}
@Override
public int recv(long bufferPtr, int bufferLen) {
return nf.recvRaw(fd, bufferPtr, bufferLen);
}
@Override
public int send(long bufferPtr, int bufferLen) {
return nf.sendRaw(fd, bufferPtr, bufferLen);
}
@Override
public int shutdown(int how) {
return nf.shutdown(fd, how);
}
@Override
public int startTlsSession() {
throw new UnsupportedOperationException();
}
@Override
public boolean supportsTls() {
return false;
}
@Override
public int tlsIO(int readinessFlags) {
return 0;
}
@Override
public boolean wantsTlsRead() {
return false;
}
@Override
public boolean wantsTlsWrite() {
return false;
}
}
/*******************************************************************************
* ___ _ ____ ____
* / _ \ _ _ ___ ___| |_| _ \| __ )
* | | | | | | |/ _ \/ __| __| | | | _ \
* | |_| | |_| | __/\__ \ |_| |_| | |_) |
* \__\_\\__,_|\___||___/\__|____/|____/
*
* Copyright (c) 2014-2019 Appsicle
* Copyright (c) 2019-2023 QuestDB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
package io.questdb.network;
import io.questdb.log.Log;
public class PlainSocketFactory implements SocketFactory {
public static final SocketFactory INSTANCE = new PlainSocketFactory();
@Override
public Socket newInstance(NetworkFacade nf, Log log) {
return new PlainSocket(nf, log);
}
}
/*******************************************************************************
* ___ _ ____ ____
* / _ \ _ _ ___ ___| |_| _ \| __ )
* | | | | | | |/ _ \/ __| __| | | | _ \
* | |_| | |_| | __/\__ \ |_| |_| | |_) |
* \__\_\\__,_|\___||___/\__|____/|____/
*
* Copyright (c) 2014-2019 Appsicle
* Copyright (c) 2019-2023 QuestDB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
package io.questdb.network;
import io.questdb.std.QuietCloseable;
/**
* Abstraction for plain and encrypted TCP sockets. Encrypted sockets use additional buffer
* to accumulate messages, so they require extra calls to convert encrypted data to raw data.
* <p>
* {@link #close()} implementations must be idempotent. Also, supports object reuse after
* {@link #close()}: see {@link #of(int)}.
*/
public interface Socket extends QuietCloseable {
int READ_FLAG = 1 << 1;
int WRITE_FLAG = 1;
/**
* @return file descriptor associated with the socket.
*/
int getFd();
/**
* @return true if TLS session was already started.
*/
boolean isTlsSessionStarted();
/**
* Sets the file descriptor associated with the socket.
* The socket owns the fd after this call.
*
* @param fd file descriptor
*/
void of(int fd);
/**
* Receives plain data into the given buffer from the socket. On encrypted
* sockets this call includes {@link #tlsIO(int)}, so an extra tlsIO()
* call is not required.
*
* @param bufferPtr pointer to the buffer
* @param bufferLen buffer length
* @return recv() result; non-negative if there were no errors.
*/
int recv(long bufferPtr, int bufferLen);
/**
* Sends plain data from the given buffer to the socket. On encrypted
* sockets this call includes {@link #tlsIO(int)}, so an extra tlsIO()
* call is not required.
*
* @param bufferPtr pointer to the buffer
* @param bufferLen buffer length
* @return send() result; non-negative if there were no errors.
*/
int send(long bufferPtr, int bufferLen);
/**
* Does a shutdown() call on the socket.
*
* @param how valid shutdown flag, e.g. {@link Net#SHUT_WR}.
* @return 0 if the call is successful; -1 if there was an error.
*/
int shutdown(int how);
/**
* Starts a TLS session, if supported.
*
* @return 0 if the call is successful; -1 if there was an error.
*/
int startTlsSession();
/**
* @return true if the socket support TLS encryption; false otherwise.
*/
boolean supportsTls();
/**
* Reads or writes encrypted data to/from the internal buffer from/to
* the socket. Can be called safely even if the socket doesn't
* support TLS.
*
* @param readinessFlags socket readiness flags (see {@link #READ_FLAG}
* and {@link #WRITE_FLAG}).
* @return 0 if the call is successful; -1 if there was an error.
*/
int tlsIO(int readinessFlags);
/**
* @return true if a {@link #tlsIO(int)} call should be made once
* the socket becomes readable.
*/
boolean wantsTlsRead();
/**
* @return true if a {@link #tlsIO(int)} call should be made once
* the socket becomes writable.
*/
boolean wantsTlsWrite();
}
/*******************************************************************************
* ___ _ ____ ____
* / _ \ _ _ ___ ___| |_| _ \| __ )
* | | | | | | |/ _ \/ __| __| | | | _ \
* | |_| | |_| | __/\__ \ |_| |_| | |_) |
* \__\_\\__,_|\___||___/\__|____/|____/
*
* Copyright (c) 2014-2019 Appsicle
* Copyright (c) 2019-2023 QuestDB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
package io.questdb.network;
import io.questdb.log.Log;
@FunctionalInterface
public interface SocketFactory {
Socket newInstance(NetworkFacade nf, Log log);
}
......@@ -39,7 +39,7 @@ public class WeakMutableObjectPool<T extends Mutable> extends WeakObjectPoolBase
@Override
public void close() {
while (cache.size() > 0) {
while (!cache.isEmpty()) {
Misc.freeIfCloseable(cache.pop());
}
}
......
......@@ -48,7 +48,8 @@ public class GcUtf8String implements DirectUtf8Sequence {
// ***** NOTE *****
// This class causes garbage collection.
// It should be used with care.
// It is currently intended to be used exclusively for the `dirName` and `tableName` fields of `TableToken`.
// It is currently intended to be used for the `dirName` and `tableName` fields
// of `TableToken` and similar things.
this.original = original;
final byte[] bytes = original.getBytes(StandardCharsets.UTF_8);
this.buffer = ByteBuffer.allocateDirect(bytes.length);
......
......@@ -145,8 +145,7 @@ public class WalListenerTest extends AbstractCairoTest {
releaseInactive(engine);
// Empty segment does not generate close event
Assert.assertEquals(0, listener.events.size()
);
Assert.assertEquals(0, listener.events.size());
}
}
......@@ -258,10 +257,10 @@ public class WalListenerTest extends AbstractCairoTest {
}
@Override
public void segmentClosed(TableToken tabletoken, long txn, int walId, int segmentId) {
public void segmentClosed(TableToken tableToken, long txn, int walId, int segmentId) {
events.add(new WalListenerEvent(
WalListenerEventType.SEGMENT_CLOSED,
tabletoken,
tableToken,
txn,
0,
walId,
......@@ -319,8 +318,8 @@ public class WalListenerTest extends AbstractCairoTest {
public final int segmentId;
public final int segmentTxn;
public final TableToken tableToken;
public final long txn;
public final long timestamp;
public final long txn;
public final WalListenerEventType type;
public final int walId;
......
......@@ -26,17 +26,14 @@ package io.questdb.test.client;
import io.questdb.client.Sender;
import io.questdb.cutlass.line.LineSenderException;
import io.questdb.network.NetworkFacadeImpl;
import io.questdb.std.Files;
import io.questdb.test.cutlass.line.tcp.AbstractLineTcpReceiverTest;
import io.questdb.network.NetworkFacadeImpl;
import io.questdb.test.tools.TestUtils;
import io.questdb.test.tools.TlsProxyRule;
import org.junit.ClassRule;
import org.junit.Test;
import java.net.URL;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
public class LineSenderBuilderTest extends AbstractLineTcpReceiverTest {
......@@ -184,7 +181,7 @@ public class LineSenderBuilderTest extends AbstractLineTcpReceiverTest {
authKeyId = AUTH_KEY_ID1;
nf = new NetworkFacadeImpl() {
@Override
public int recv(int fd, long buffer, int bufferLen) {
public int recvRaw(int fd, long buffer, int bufferLen) {
// force server to fail to receive userId and this disconnect
// mid-authentication
return -1;
......@@ -310,9 +307,7 @@ public class LineSenderBuilderTest extends AbstractLineTcpReceiverTest {
@Test
public void testConnectTls_TruststoreFile() throws Exception {
URL trustStoreResource = LineSenderBuilderTest.class.getResource(TRUSTSTORE_PATH);
assertNotNull("Someone accidentally deleted trust store?", trustStoreResource);
String truststore = trustStoreResource.getFile();
String truststore = TestUtils.getTestResourcePath(TRUSTSTORE_PATH);
runInContext(r -> {
try (Sender sender = Sender.builder()
.address(LOCALHOST)
......
......@@ -464,7 +464,6 @@ public class IODispatcherHeartbeatTest {
private static class TestContext extends IOContext<TestContext> {
private final long buffer = Unsafe.malloc(4, MemoryTag.NATIVE_DEFAULT);
private final IODispatcher<TestContext> dispatcher;
private final int fd;
private final long heartbeatInterval;
boolean isPreviousEventHeartbeat = true;
long previousHeartbeatTs;
......@@ -472,7 +471,8 @@ public class IODispatcherHeartbeatTest {
SuspendEvent suspendEvent;
public TestContext(int fd, IODispatcher<TestContext> dispatcher, long heartbeatInterval) {
this.fd = fd;
super(PlainSocketFactory.INSTANCE, NetworkFacadeImpl.INSTANCE, LOG);
socket.of(fd);
this.dispatcher = dispatcher;
this.heartbeatInterval = heartbeatInterval;
}
......@@ -508,7 +508,7 @@ public class IODispatcherHeartbeatTest {
@Override
public void close() {
Unsafe.free(buffer, 4, MemoryTag.NATIVE_DEFAULT);
suspendEvent = Misc.free(suspendEvent);
super.close();
}
@Override
......@@ -516,11 +516,6 @@ public class IODispatcherHeartbeatTest {
return dispatcher;
}
@Override
public int getFd() {
return fd;
}
@Override
public SuspendEvent getSuspendEvent() {
return suspendEvent;
......
......@@ -66,7 +66,7 @@ public class NetUtils {
if (mode == 0) {
// we were sending - lets wrap up and send
if (expectedLen > 0) {
int m = nf.send(clientFd, sendBuf, expectedLen);
int m = nf.sendRaw(clientFd, sendBuf, expectedLen);
// if we expect disconnect we might get it on either `send` or `recv`
// check if we expect disconnect on recv?
if (m == -2 && script.charAt(i + 1) == '!' && script.charAt(i + 2) == '!') {
......@@ -81,7 +81,7 @@ public class NetUtils {
// we meant to receive; sendBuf will contain expected bytes we have to receive
// and this buffer will also drive the length of the message
if (expectedLen > 0) {
int actualLen = nf.recv(clientFd, recvBuf, expectedLen);
int actualLen = nf.recvRaw(clientFd, recvBuf, expectedLen);
if (expectDisconnect) {
Assert.assertTrue(actualLen < 0);
// force exit
......@@ -128,7 +128,7 @@ public class NetUtils {
if (mode == 0) {
// we were sending - lets wrap up and send
if (expectedLen > 0) {
int m = nf.send(clientFd, sendBuf, expectedLen);
int m = nf.sendRaw(clientFd, sendBuf, expectedLen);
Assert.assertEquals(expectedLen, m);
}
} else {
......@@ -138,7 +138,7 @@ public class NetUtils {
if (expectDisconnect) {
Assert.assertTrue(Net.isDead(clientFd));
} else {
int actualLen = nf.recv(clientFd, recvBuf, expectedLen);
int actualLen = nf.recvRaw(clientFd, recvBuf, expectedLen);
assertBuffers(line, sendBuf, expectedLen, recvBuf, actualLen);
}
}
......
......@@ -32,6 +32,7 @@ import io.questdb.cutlass.http.HttpRequestHeader;
import io.questdb.network.NetworkFacadeImpl;
import io.questdb.std.Chars;
import io.questdb.test.AbstractTest;
import org.jetbrains.annotations.NotNull;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
......@@ -627,7 +628,7 @@ public class HttpSecurityTest extends AbstractTest {
) throws Exception {
final FactoryProvider factoryProvider = new DefaultFactoryProvider() {
@Override
public HttpAuthenticatorFactory getHttpAuthenticatorFactory() {
public @NotNull HttpAuthenticatorFactory getHttpAuthenticatorFactory() {
return factory;
}
};
......
......@@ -133,7 +133,6 @@ public class IODispatcherTest extends AbstractTest {
LOG.info().$("started testBiasWrite").$();
assertMemoryLeak(() -> {
SOCountDownLatch connectLatch = new SOCountDownLatch(1);
SOCountDownLatch contextClosedLatch = new SOCountDownLatch(1);
......@@ -2161,10 +2160,10 @@ public class IODispatcherTest extends AbstractTest {
"--------------------------27d997ca93d2689d--",
new NetworkFacadeImpl() {
@Override
public int send(int fd, long buffer, int bufferLen) {
public int sendRaw(int fd, long buffer, int bufferLen) {
// ensure we do not send more than one byte at a time
if (bufferLen > 0) {
return super.send(fd, buffer, 1);
return super.sendRaw(fd, buffer, 1);
}
return 0;
}
......@@ -2268,9 +2267,9 @@ public class IODispatcherTest extends AbstractTest {
int totalSent = 0;
@Override
public int send(int fd, long buffer, int bufferLen) {
public int sendRaw(int fd, long buffer, int bufferLen) {
if (bufferLen > 0) {
int result = super.send(fd, buffer, 1);
int result = super.sendRaw(fd, buffer, 1);
totalSent += result;
// start delaying after 800 bytes
......@@ -3405,7 +3404,7 @@ public class IODispatcherTest extends AbstractTest {
int reqLen = request.length();
Chars.asciiStrCpy(request, reqLen, ptr);
while (sent < reqLen) {
int n = NetworkFacadeImpl.INSTANCE.send(fd, ptr + sent, reqLen - sent);
int n = NetworkFacadeImpl.INSTANCE.sendRaw(fd, ptr + sent, reqLen - sent);
Assert.assertTrue(n > -1);
sent += n;
}
......@@ -3415,7 +3414,7 @@ public class IODispatcherTest extends AbstractTest {
nf.configureNonBlocking(fd);
long t = System.currentTimeMillis();
boolean disconnected = true;
while (nf.recv(fd, ptr, 1) > -1) {
while (nf.recvRaw(fd, ptr, 1) > -1) {
if (t + 20000 < System.currentTimeMillis()) {
disconnected = false;
break;
......@@ -5579,7 +5578,7 @@ public class IODispatcherTest extends AbstractTest {
Chars.asciiStrCpy(request, reqLen, ptr);
boolean disconnected = false;
while (sent < reqLen) {
int n = nf.send(fd, ptr + sent, reqLen - sent);
int n = nf.sendRaw(fd, ptr + sent, reqLen - sent);
if (n < 0) {
disconnected = true;
break;
......@@ -5590,7 +5589,7 @@ public class IODispatcherTest extends AbstractTest {
}
if (!disconnected) {
while (true) {
int n = nf.recv(fd, ptr, len);
int n = nf.recvRaw(fd, ptr, len);
if (n < 0) {
break;
}
......@@ -7719,7 +7718,7 @@ public class IODispatcherTest extends AbstractTest {
requestsReceived.incrementAndGet();
nf.send(context.getFd(), responseBuf, 1);
nf.sendRaw(context.getFd(), responseBuf, 1);
}
};
......@@ -8536,17 +8535,19 @@ public class IODispatcherTest extends AbstractTest {
final long queuedConnectionTimeoutInMs = 250;
class TestIOContext extends IOContext<TestIOContext> {
private final int fd;
private final IntHashSet serverConnectedFds;
private long heartbeatId;
public TestIOContext(int fd, IntHashSet serverConnectedFds) {
this.fd = fd;
super(PlainSocketFactory.INSTANCE, NetworkFacadeImpl.INSTANCE, LOG);
socket.of(fd);
this.serverConnectedFds = serverConnectedFds;
}
@Override
public void close() {
final int fd = getFd();
super.close();
LOG.info().$(fd).$(" disconnected").$();
serverConnectedFds.remove(fd);
}
......@@ -8556,14 +8557,9 @@ public class IODispatcherTest extends AbstractTest {
return heartbeatId;
}
@Override
public int getFd() {
return fd;
}
@Override
public boolean invalid() {
return !serverConnectedFds.contains(fd);
return !serverConnectedFds.contains(getFd());
}
@Override
......@@ -8801,7 +8797,8 @@ public class IODispatcherTest extends AbstractTest {
private final SOCountDownLatch closeLatch;
public HelloContext(int fd, SOCountDownLatch closeLatch, IODispatcher<HelloContext> dispatcher) {
this.fd = fd;
super(PlainSocketFactory.INSTANCE, NetworkFacadeImpl.INSTANCE, LOG);
socket.of(fd);
this.closeLatch = closeLatch;
this.dispatcher = dispatcher;
}
......@@ -8810,6 +8807,7 @@ public class IODispatcherTest extends AbstractTest {
public void close() {
Unsafe.free(buffer, 1024, MemoryTag.NATIVE_DEFAULT);
closeLatch.countDown();
super.close();
}
@Override
......
......@@ -27,9 +27,7 @@ package io.questdb.test.cutlass.http;
import io.questdb.cutlass.http.HttpException;
import io.questdb.cutlass.http.MimeTypesCache;
import io.questdb.std.Chars;
import io.questdb.std.Files;
import io.questdb.std.FilesFacade;
import io.questdb.std.Os;
import io.questdb.std.str.LPSZ;
import io.questdb.std.str.Path;
import io.questdb.test.AbstractTest;
......@@ -128,19 +126,11 @@ public class MimeTypesCacheTest extends AbstractTest {
@Test
public void testSimple() throws Exception {
TestUtils.assertMemoryLeak(new TestUtils.LeakProneCode() {
@Override
public void run() {
try (Path path = new Path()) {
String filePath;
if (Os.isWindows()) {
filePath = Files.getResourcePath(getClass().getResource("/mime.types")).substring(1);
} else {
filePath = Files.getResourcePath(getClass().getResource("/mime.types"));
}
path.of(filePath).$();
assertMimeTypes(new MimeTypesCache(TestFilesFacadeImpl.INSTANCE, path));
}
TestUtils.assertMemoryLeak(() -> {
try (Path path = new Path()) {
String filePath = TestUtils.getTestResourcePath("/mime.types");
path.of(filePath).$();
assertMimeTypes(new MimeTypesCache(TestFilesFacadeImpl.INSTANCE, path));
}
});
}
......
......@@ -127,7 +127,7 @@ public class SendAndReceiveRequestBuilder {
int reqLen = request.length();
Chars.asciiStrCpy(request, reqLen, ptr);
while (sent < reqLen) {
int n = nf.send(fd, ptr + sent, reqLen - sent);
int n = nf.sendRaw(fd, ptr + sent, reqLen - sent);
if (n < 0 && expectSendDisconnect) {
return;
}
......@@ -150,7 +150,7 @@ public class SendAndReceiveRequestBuilder {
boolean timeoutExpired = false;
IntList receivedByteList = new IntList(expectedToReceive);
while (received < expectedToReceive || expectReceiveDisconnect) {
int n = nf.recv(fd, ptr + received, len - received);
int n = nf.recvRaw(fd, ptr + received, len - received);
if (n > 0) {
for (int i = 0; i < n; i++) {
receivedByteList.add(Unsafe.getUnsafe().getByte(ptr + received + i) & 0xff);
......@@ -256,7 +256,7 @@ public class SendAndReceiveRequestBuilder {
int reqLen = request.length();
Chars.asciiStrCpy(request, reqLen, ptr);
while (sent < reqLen) {
int n = nf.send(fd, ptr + sent, reqLen - sent);
int n = nf.sendRaw(fd, ptr + sent, reqLen - sent);
Assert.assertTrue(n > -1);
sent += n;
}
......@@ -269,7 +269,7 @@ public class SendAndReceiveRequestBuilder {
int received = 0;
IntList receivedByteList = new IntList();
while (true) {
int n = nf.recv(fd, ptr + received, len - received);
int n = nf.recvRaw(fd, ptr + received, len - received);
if (n > 0) {
for (int i = 0; i < n; i++) {
receivedByteList.add(Unsafe.getUnsafe().getByte(ptr + received + i));
......
......@@ -31,30 +31,22 @@ import io.questdb.cutlass.json.JsonLexer;
import io.questdb.cutlass.json.JsonParser;
import io.questdb.cutlass.line.LineSenderException;
import io.questdb.cutlass.line.LineTcpSender;
import io.questdb.test.cutlass.line.tcp.StringChannel;
import io.questdb.griffin.SqlKeywords;
import io.questdb.std.*;
import io.questdb.std.str.Path;
import io.questdb.std.str.StringSink;
import io.questdb.test.cutlass.line.tcp.StringChannel;
import io.questdb.test.std.TestFilesFacadeImpl;
import io.questdb.test.tools.TestUtils;
import org.junit.Assert;
import org.junit.Test;
import java.net.URL;
public class TestInterop {
@Test
public void testInterop() throws Exception {
FilesFacade ff = TestFilesFacadeImpl.INSTANCE;
URL testCasesUrl = TestInterop.class.getResource("/io/questdb/test/cutlass/line/interop/ilp-client-interop-test.json");
Assert.assertNotNull("interop test cases missing", testCasesUrl);
String pp = testCasesUrl.getFile();
if (Os.isWindows()) {
// on Windows Java returns "/C:/dir/file". This leading slash is Java specific and doesn't bode well
// with OS file open methods.
pp = pp.substring(1);
}
String pp = TestUtils.getTestResourcePath("/io/questdb/test/cutlass/line/interop/ilp-client-interop-test.json");
StringChannel channel = new StringChannel();
try (JsonLexer lexer = new JsonLexer(1024, 1024);
......
......@@ -44,6 +44,7 @@ import io.questdb.std.str.Path;
import io.questdb.test.AbstractCairoTest;
import io.questdb.test.mp.TestWorkerPool;
import io.questdb.test.tools.TestUtils;
import org.jetbrains.annotations.NotNull;
import org.junit.After;
import org.junit.Assert;
......@@ -73,20 +74,20 @@ public class AbstractLineTcpReceiverTest extends AbstractCairoTest {
private final static Log LOG = LogFactory.getLog(AbstractLineTcpReceiverTest.class);
protected final int bindPort = 9002; // Don't clash with other tests since they may run in parallel
protected final WorkerPool sharedWorkerPool = new TestWorkerPool(getWorkerCount(), metrics);
private final IODispatcherConfiguration ioDispatcherConfiguration = new DefaultIODispatcherConfiguration() {
@Override
public int getBindPort() {
return bindPort;
}
private final ThreadLocal<Socket> tlSocket = new ThreadLocal<>();
protected String authKeyId = null;
private final FactoryProvider factoryProvider = new DefaultFactoryProvider() {
@Override
public long getHeartbeatInterval() {
return 15;
public @NotNull LineAuthenticatorFactory getLineAuthenticatorFactory() {
if (authKeyId == null) {
return super.getLineAuthenticatorFactory();
}
URL u = getClass().getResource("authDb.txt");
assert u != null;
CharSequenceObjHashMap<PublicKey> authDb = AuthUtils.loadAuthDb(u.getFile());
return new EllipticCurveAuthenticatorFactory(() -> new StaticChallengeResponseMatcher(authDb));
}
};
private final ThreadLocal<Socket> tlSocket = new ThreadLocal<>();
protected String authKeyId = null;
protected boolean autoCreateNewColumns = true;
protected long commitIntervalDefault = 2000;
protected double commitIntervalFraction = 0.5;
......@@ -96,16 +97,20 @@ public class AbstractLineTcpReceiverTest extends AbstractCairoTest {
protected long minIdleMsBeforeWriterRelease = 30000;
protected int msgBufferSize = 256 * 1024;
protected NetworkFacade nf = NetworkFacadeImpl.INSTANCE;
private final FactoryProvider factoryProvider = new DefaultFactoryProvider() {
private final IODispatcherConfiguration ioDispatcherConfiguration = new DefaultIODispatcherConfiguration() {
@Override
public LineAuthenticatorFactory getLineAuthenticatorFactory() {
if (authKeyId == null) {
return super.getLineAuthenticatorFactory();
}
URL u = getClass().getResource("authDb.txt");
assert u != null;
CharSequenceObjHashMap<PublicKey> authDb = AuthUtils.loadAuthDb(u.getFile());
return new EllipticCurveAuthenticatorFactory(nf, () -> new StaticChallengeResponseMatcher(authDb));
public int getBindPort() {
return bindPort;
}
@Override
public long getHeartbeatInterval() {
return 15;
}
@Override
public NetworkFacade getNetworkFacade() {
return nf;
}
};
protected int partitionByDefault = PartitionBy.DAY;
......
......@@ -120,7 +120,7 @@ abstract class BaseLineTcpContextTest extends AbstractCairoTest {
}
protected void closeContext() {
if (null != scheduler) {
if (scheduler != null) {
workerPool.halt();
Assert.assertFalse(context.invalid());
Assert.assertEquals(FD, context.getFd());
......@@ -140,12 +140,12 @@ abstract class BaseLineTcpContextTest extends AbstractCairoTest {
protected LineTcpReceiverConfiguration createReceiverConfiguration(final boolean withAuth, final NetworkFacade nf) {
final FactoryProvider factoryProvider = new DefaultFactoryProvider() {
@Override
public LineAuthenticatorFactory getLineAuthenticatorFactory() {
public @NotNull LineAuthenticatorFactory getLineAuthenticatorFactory() {
if (withAuth) {
URL u = getClass().getResource("authDb.txt");
assert u != null;
CharSequenceObjHashMap<PublicKey> authDb = AuthUtils.loadAuthDb(u.getFile());
return new EllipticCurveAuthenticatorFactory(nf, () -> new StaticChallengeResponseMatcher(authDb));
return new EllipticCurveAuthenticatorFactory(() -> new StaticChallengeResponseMatcher(authDb));
}
return super.getLineAuthenticatorFactory();
}
......@@ -265,11 +265,6 @@ abstract class BaseLineTcpContextTest extends AbstractCairoTest {
});
}
@FunctionalInterface
public interface UnstableRunnable {
void run() throws Exception;
}
protected void runInContext(UnstableRunnable r) throws Exception {
runInContext(r, null);
}
......@@ -382,6 +377,11 @@ abstract class BaseLineTcpContextTest extends AbstractCairoTest {
Os.sleep(lineTcpConfiguration.getMaintenanceInterval() + 50);
}
@FunctionalInterface
public interface UnstableRunnable {
void run() throws Exception;
}
static class NoNetworkIOJob implements NetworkIOJob {
private final ByteCharSequenceObjHashMap<TableUpdateDetails> localTableUpdateDetailsByTableName = new ByteCharSequenceObjHashMap<>();
private final ObjList<SymbolCache> unusedSymbolCaches = new ObjList<>();
......@@ -440,7 +440,12 @@ abstract class BaseLineTcpContextTest extends AbstractCairoTest {
class LineTcpNetworkFacade extends NetworkFacadeImpl {
@Override
public int recv(int fd, long buffer, int bufferLen) {
public void close(int fd, Log log) {
Assert.assertEquals(FD, fd);
}
@Override
public int recvRaw(int fd, long buffer, int bufferLen) {
Assert.assertEquals(FD, fd);
if (recvBuffer == null) {
return -1;
......
......@@ -63,7 +63,7 @@ public class EllipticCurveAuthConnectionContextTest extends BaseLineTcpContextTe
integerDefaultColumnType = ColumnType.LONG;
lineTcpConfiguration = createReceiverConfiguration(true, new LineTcpNetworkFacade() {
@Override
public int send(int fd, long buffer, int bufferLen) {
public int sendRaw(int fd, long buffer, int bufferLen) {
Assert.assertEquals(FD, fd);
if (null != sentBytes) {
return 0;
......@@ -305,7 +305,8 @@ public class EllipticCurveAuthConnectionContextTest extends BaseLineTcpContextTe
false,
false,
true,
null);
null
);
Assert.assertTrue(authSequenceCompleted);
} catch (RuntimeException ex) {
// Expected that Java 8 does not have SHA256withECDSAinP1363
......@@ -331,7 +332,8 @@ public class EllipticCurveAuthConnectionContextTest extends BaseLineTcpContextTe
runInAuthContext(() -> {
try {
boolean authSequenceCompleted = authenticate(AUTH_KEY_ID1, AUTH_PRIVATE_KEY1,
"weather,location=us-midwest temperature=82 1465839830100400200\n");
"weather,location=us-midwest temperature=82 1465839830100400200\n"
);
Assert.assertTrue(authSequenceCompleted);
} catch (RuntimeException ex) {
// Expected that Java 8 does not have SHA256withECDSAinP1363
......@@ -421,25 +423,30 @@ public class EllipticCurveAuthConnectionContextTest extends BaseLineTcpContextTe
false,
false,
null,
extraData);
extraData
);
}
private boolean authenticate(boolean fragmentKeyId,
boolean fragmentChallenge,
boolean fragmentSignature,
boolean useP1363Encoding,
byte[] junkSignature) {
private boolean authenticate(
boolean fragmentKeyId,
boolean fragmentChallenge,
boolean fragmentSignature,
boolean useP1363Encoding,
byte[] junkSignature
) {
return authenticate(AbstractLineTcpReceiverTest.AUTH_KEY_ID1, AbstractLineTcpReceiverTest.AUTH_PRIVATE_KEY1, fragmentKeyId, fragmentChallenge, fragmentSignature, useP1363Encoding, junkSignature, "");
}
private boolean authenticate(String authKeyId,
PrivateKey authPrivateKey,
boolean fragmentKeyId,
boolean fragmentChallenge,
boolean fragmentSignature,
boolean useP1363Encoding,
byte[] junkSignature,
String extraData) {
private boolean authenticate(
String authKeyId,
PrivateKey authPrivateKey,
boolean fragmentKeyId,
boolean fragmentChallenge,
boolean fragmentSignature,
boolean useP1363Encoding,
byte[] junkSignature,
String extraData
) {
send(authKeyId + "\n", fragmentKeyId);
byte[] challengeBytes = readChallenge(fragmentChallenge);
if (null == challengeBytes) {
......
......@@ -109,55 +109,6 @@ public abstract class BasePGTest extends AbstractCairoTest {
);
}
private static void toSink(InputStream is, CharSink sink) throws IOException {
// limit what we print
byte[] bb = new byte[1];
int i = 0;
while (is.read(bb) > 0) {
byte b = bb[0];
if (i > 0) {
if ((i % 16) == 0) {
sink.put('\n');
Numbers.appendHexPadded(sink, i);
}
} else {
Numbers.appendHexPadded(sink, i);
}
sink.put(' ');
final int v;
if (b < 0) {
v = 256 + b;
} else {
v = b;
}
if (v < 0x10) {
sink.put('0');
sink.put(hexDigits[b]);
} else {
sink.put(hexDigits[v / 0x10]);
sink.put(hexDigits[v % 0x10]);
}
i++;
}
}
protected static void assertResultSet(CharSequence expected, StringSink sink, ResultSet rs, @Nullable IntIntHashMap map) throws SQLException, IOException {
assertResultSet(null, expected, sink, rs, map);
}
protected static void assertResultSet(String message, CharSequence expected, StringSink sink, ResultSet rs, @Nullable IntIntHashMap map) throws SQLException, IOException {
printToSink(sink, rs, map);
TestUtils.assertEquals(message, expected, sink);
}
protected static void assertResultSet(String message, CharSequence expected, StringSink sink, ResultSet rs) throws SQLException, IOException {
printToSink(sink, rs, null);
TestUtils.assertEquals(message, expected, sink);
}
public static long printToSink(StringSink sink, ResultSet rs, @Nullable IntIntHashMap map) throws SQLException, IOException {
// dump metadata
ResultSetMetaData metaData = rs.getMetaData();
......@@ -290,6 +241,55 @@ public abstract class BasePGTest extends AbstractCairoTest {
return rows;
}
private static void toSink(InputStream is, CharSink sink) throws IOException {
// limit what we print
byte[] bb = new byte[1];
int i = 0;
while (is.read(bb) > 0) {
byte b = bb[0];
if (i > 0) {
if ((i % 16) == 0) {
sink.put('\n');
Numbers.appendHexPadded(sink, i);
}
} else {
Numbers.appendHexPadded(sink, i);
}
sink.put(' ');
final int v;
if (b < 0) {
v = 256 + b;
} else {
v = b;
}
if (v < 0x10) {
sink.put('0');
sink.put(hexDigits[b]);
} else {
sink.put(hexDigits[v / 0x10]);
sink.put(hexDigits[v % 0x10]);
}
i++;
}
}
protected static void assertResultSet(CharSequence expected, StringSink sink, ResultSet rs, @Nullable IntIntHashMap map) throws SQLException, IOException {
assertResultSet(null, expected, sink, rs, map);
}
protected static void assertResultSet(String message, CharSequence expected, StringSink sink, ResultSet rs, @Nullable IntIntHashMap map) throws SQLException, IOException {
printToSink(sink, rs, map);
TestUtils.assertEquals(message, expected, sink);
}
protected static void assertResultSet(String message, CharSequence expected, StringSink sink, ResultSet rs) throws SQLException, IOException {
printToSink(sink, rs, null);
TestUtils.assertEquals(message, expected, sink);
}
protected PGWireServer createPGServer(PGWireConfiguration configuration) throws SqlException {
TestWorkerPool workerPool = new TestWorkerPool(configuration.getWorkerCount(), metrics);
copyRequestJob = new CopyRequestJob(engine, configuration.getWorkerCount());
......@@ -423,10 +423,10 @@ public abstract class BasePGTest extends AbstractCairoTest {
protected NetworkFacade getFragmentedSendFacade() {
return new NetworkFacadeImpl() {
@Override
public int send(int fd, long buffer, int bufferLen) {
public int sendRaw(int fd, long buffer, int bufferLen) {
int total = 0;
for (int i = 0; i < bufferLen; i++) {
int n = super.send(fd, buffer + i, 1);
int n = super.sendRaw(fd, buffer + i, 1);
if (n < 0) {
return n;
}
......
......@@ -33,6 +33,7 @@ import io.questdb.cutlass.pgwire.ReadOnlyUsersAwareSecurityContextFactory;
import io.questdb.mp.WorkerPool;
import io.questdb.std.Os;
import io.questdb.test.tools.TestUtils;
import org.jetbrains.annotations.NotNull;
import org.junit.*;
import org.postgresql.PGProperty;
import org.postgresql.util.PSQLException;
......@@ -51,7 +52,7 @@ public class PGSecurityTest extends BasePGTest {
private static final SecurityContextFactory READ_ONLY_SECURITY_CONTEXT_FACTORY = new ReadOnlyUsersAwareSecurityContextFactory(true, null, false);
private static final FactoryProvider READ_ONLY_FACTORY_PROVIDER = new DefaultFactoryProvider() {
@Override
public SecurityContextFactory getSecurityContextFactory() {
public @NotNull SecurityContextFactory getSecurityContextFactory() {
return READ_ONLY_SECURITY_CONTEXT_FACTORY;
}
};
......@@ -64,7 +65,7 @@ public class PGSecurityTest extends BasePGTest {
private static final SecurityContextFactory READ_ONLY_USER_SECURITY_CONTEXT_FACTORY = new ReadOnlyUsersAwareSecurityContextFactory(false, "user", false);
private static final FactoryProvider READ_ONLY_USER_FACTORY_PROVIDER = new DefaultFactoryProvider() {
@Override
public SecurityContextFactory getSecurityContextFactory() {
public @NotNull SecurityContextFactory getSecurityContextFactory() {
return READ_ONLY_USER_SECURITY_CONTEXT_FACTORY;
}
};
......
/*******************************************************************************
* ___ _ ____ ____
* / _ \ _ _ ___ ___| |_| _ \| __ )
* | | | | | | |/ _ \/ __| __| | | | _ \
* | |_| | |_| | __/\__ \ |_| |_| | |_) |
* \__\_\\__,_|\___||___/\__|____/|____/
*
* Copyright (c) 2014-2019 Appsicle
* Copyright (c) 2019-2023 QuestDB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
package io.questdb.test.cutlass.pgwire;
import io.questdb.DefaultFactoryProvider;
import io.questdb.FactoryProvider;
import io.questdb.cutlass.pgwire.PGWireConfiguration;
import io.questdb.cutlass.pgwire.PGWireServer;
import io.questdb.log.Log;
import io.questdb.mp.WorkerPool;
import io.questdb.network.NetworkFacade;
import io.questdb.network.PlainSocket;
import io.questdb.network.Socket;
import io.questdb.network.SocketFactory;
import io.questdb.test.mp.TestWorkerPool;
import org.jetbrains.annotations.NotNull;
import org.junit.Assert;
import org.junit.Test;
import org.postgresql.util.PSQLException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicInteger;
public class PGTlsCompatTest extends BasePGTest {
@Test
public void testTlsSessionGetsCreatedWhenSocketSupportsTls() throws Exception {
assertMemoryLeak(() -> {
final AtomicInteger createTlsSessionCalls = new AtomicInteger();
final AtomicInteger tlsIOCalls = new AtomicInteger();
final PGWireConfiguration conf = new Port0PGWireConfiguration() {
@Override
public FactoryProvider getFactoryProvider() {
return new DefaultFactoryProvider() {
@Override
public @NotNull SocketFactory getPGWireSocketFactory() {
return new FakeTlsSocketFactory(true, createTlsSessionCalls, tlsIOCalls);
}
};
}
};
final WorkerPool workerPool = new TestWorkerPool(1, metrics);
try (final PGWireServer server = createPGWireServer(conf, engine, workerPool)) {
Assert.assertNotNull(server);
workerPool.start(LOG);
Properties properties = newPGProperties();
properties.setProperty("sslmode", "require");
final String url = String.format("jdbc:postgresql://127.0.0.1:%d/qdb", server.getPort());
try (Connection ignore = DriverManager.getConnection(url, properties)) {
Assert.fail();
} catch (PSQLException ignore) {
} finally {
workerPool.halt();
}
Assert.assertTrue("Some create TLS session calls expected: " + createTlsSessionCalls.get(), createTlsSessionCalls.get() > 0);
Assert.assertTrue("Some TLS I/O calls expected: " + tlsIOCalls.get(), tlsIOCalls.get() > 0);
}
});
}
@Test
public void testTlsSessionIsNotCreatedWhenSocketDoesNotSupportTls() throws Exception {
assertMemoryLeak(() -> {
final AtomicInteger createTlsSessionCalls = new AtomicInteger();
final AtomicInteger tlsIOCalls = new AtomicInteger();
final PGWireConfiguration conf = new Port0PGWireConfiguration() {
@Override
public FactoryProvider getFactoryProvider() {
return new DefaultFactoryProvider() {
@Override
public @NotNull SocketFactory getPGWireSocketFactory() {
return new FakeTlsSocketFactory(false, createTlsSessionCalls, tlsIOCalls);
}
};
}
};
final WorkerPool workerPool = new TestWorkerPool(1, metrics);
try (final PGWireServer server = createPGWireServer(conf, engine, workerPool)) {
Assert.assertNotNull(server);
workerPool.start(LOG);
Properties properties = newPGProperties();
final String url = String.format("jdbc:postgresql://127.0.0.1:%d/qdb", server.getPort());
try (Connection ignore = DriverManager.getConnection(url, properties)) {
Assert.fail();
} catch (PSQLException ignore) {
} finally {
workerPool.halt();
}
Assert.assertEquals("No create TLS session calls expected: " + createTlsSessionCalls.get(), 0, createTlsSessionCalls.get());
Assert.assertEquals("No TLS I/O calls expected: " + tlsIOCalls.get(), 0, tlsIOCalls.get());
}
});
}
@NotNull
private static Properties newPGProperties() {
Properties properties = new Properties();
properties.setProperty("user", "admin");
properties.setProperty("password", "quest");
properties.setProperty("sslmode", "require");
properties.setProperty("binaryTransfer", "true");
return properties;
}
private static class FakeTlsSocket extends PlainSocket {
private final AtomicInteger createTlsSessionCalls;
private final AtomicInteger tlsIOCalls;
private final boolean tlsSupported;
private boolean tlsSessionStarted;
public FakeTlsSocket(NetworkFacade nf, Log log, boolean tlsSupported, AtomicInteger createTlsSessionCalls, AtomicInteger tlsIOCalls) {
super(nf, log);
this.tlsSupported = tlsSupported;
this.createTlsSessionCalls = createTlsSessionCalls;
this.tlsIOCalls = tlsIOCalls;
}
@Override
public void close() {
super.close();
tlsSessionStarted = false;
}
@Override
public boolean isTlsSessionStarted() {
return tlsSessionStarted;
}
@Override
public int startTlsSession() {
if (!tlsSessionStarted) {
createTlsSessionCalls.incrementAndGet();
tlsSessionStarted = true;
return 0;
}
return -1;
}
@Override
public boolean supportsTls() {
return tlsSupported;
}
@Override
public int tlsIO(int readinessFlags) {
if (tlsSessionStarted) {
tlsIOCalls.incrementAndGet();
return -1; // return error code to force close the connection
}
return 0;
}
@Override
public boolean wantsTlsRead() {
return tlsSessionStarted;
}
@Override
public boolean wantsTlsWrite() {
return tlsSessionStarted;
}
}
private static class FakeTlsSocketFactory implements SocketFactory {
private final AtomicInteger createTlsSessionCalls;
private final AtomicInteger tlsIOCalls;
private final boolean tlsSupported;
private FakeTlsSocketFactory(boolean tlsSupported, AtomicInteger createTlsSessionCalls, AtomicInteger tlsIOCalls) {
this.tlsSupported = tlsSupported;
this.createTlsSessionCalls = createTlsSessionCalls;
this.tlsIOCalls = tlsIOCalls;
}
@Override
public Socket newInstance(NetworkFacade nf, Log log) {
return new FakeTlsSocket(nf, log, tlsSupported, createTlsSessionCalls, tlsIOCalls);
}
}
}
......@@ -51,8 +51,6 @@ import org.junit.Rule;
import org.junit.rules.Timeout;
import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.concurrent.TimeUnit;
public class AbstractO3Test extends AbstractTest {
......@@ -243,11 +241,9 @@ public class AbstractO3Test extends AbstractTest {
SqlExecutionContext sqlExecutionContext,
String sql,
String resourceName
) throws URISyntaxException, SqlException {
) throws SqlException {
AbstractO3Test.printSqlResult(compiler, sqlExecutionContext, sql);
URL url = O3Test.class.getResource(resourceName);
Assert.assertNotNull(url);
TestUtils.assertEquals(new File(url.toURI()), sink);
TestUtils.assertEquals(new File(TestUtils.getTestResourcePath(resourceName)), sink);
}
static void assertXCount(SqlCompiler compiler, SqlExecutionContext sqlExecutionContext) throws SqlException {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册