diff --git a/broker/src/main/java/org/apache/rocketmq/broker/BrokerController.java b/broker/src/main/java/org/apache/rocketmq/broker/BrokerController.java index 0a6f0b45e9de20f282880b854ad209ec97f31cf3..18966d1766ca5f571b15c55d9427b2dbf9176a11 100644 --- a/broker/src/main/java/org/apache/rocketmq/broker/BrokerController.java +++ b/broker/src/main/java/org/apache/rocketmq/broker/BrokerController.java @@ -73,11 +73,14 @@ import org.apache.rocketmq.common.protocol.body.TopicConfigSerializeWrapper; import org.apache.rocketmq.common.stats.MomentStatsItem; import org.apache.rocketmq.remoting.RPCHook; import org.apache.rocketmq.remoting.RemotingServer; +import org.apache.rocketmq.remoting.common.TlsMode; import org.apache.rocketmq.remoting.netty.NettyClientConfig; import org.apache.rocketmq.remoting.netty.NettyRemotingServer; import org.apache.rocketmq.remoting.netty.NettyRequestProcessor; import org.apache.rocketmq.remoting.netty.NettyServerConfig; import org.apache.rocketmq.remoting.netty.RequestTask; +import org.apache.rocketmq.remoting.netty.TlsSystemConfig; +import org.apache.rocketmq.srvutil.FileWatchService; import org.apache.rocketmq.store.DefaultMessageStore; import org.apache.rocketmq.store.MessageArrivingListener; import org.apache.rocketmq.store.MessageStore; @@ -136,6 +139,7 @@ public class BrokerController { private InetSocketAddress storeHost; private BrokerFastFailure brokerFastFailure; private Configuration configuration; + private FileWatchService fileWatchService; public BrokerController( final BrokerConfig brokerConfig, @@ -387,6 +391,23 @@ public class BrokerController { } }, 1000 * 10, 1000 * 60, TimeUnit.MILLISECONDS); } + + if (TlsSystemConfig.tlsMode != TlsMode.DISABLED) { + // Register a listener to reload SslContext + try { + fileWatchService = new FileWatchService( + new String[] {TlsSystemConfig.tlsServerCertPath, TlsSystemConfig.tlsServerKeyPath}, + new FileWatchService.Listener() { + @Override + public void onChanged() { + ((NettyRemotingServer) remotingServer).loadSslContext(); + ((NettyRemotingServer) fastRemotingServer).loadSslContext(); + } + }); + } catch (IOException e) { + log.warn("FileWatchService created error, can't load the certificate dynamically"); + } + } } return result; @@ -594,6 +615,10 @@ public class BrokerController { this.fastRemotingServer.shutdown(); } + if (this.fileWatchService != null) { + this.fileWatchService.shutdown(); + } + if (this.messageStore != null) { this.messageStore.shutdown(); } @@ -662,6 +687,10 @@ public class BrokerController { this.fastRemotingServer.start(); } + if (this.fileWatchService != null) { + this.fileWatchService.start(); + } + if (this.brokerOuterAPI != null) { this.brokerOuterAPI.start(); } diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java index 76752529af24f021c6860fb90866f629d4a8a75b..557ad5602b0bf961af7c40dbba6613a9ecba34ff 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java @@ -93,7 +93,7 @@ public abstract class NettyRemotingAbstract { /** * SSL context via which to create {@link SslHandler}. */ - protected SslContext sslContext; + protected volatile SslContext sslContext; /** * Constructor, specifying capacity of one-way and asynchronous semaphores. diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java index cd6ed4704f48a36413f425a4f54b2561ba8ef01d..c8709a501498d50bb408018a3f40848d169e2deb 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java @@ -139,6 +139,10 @@ public class NettyRemotingServer extends NettyRemotingAbstract implements Remoti }); } + loadSslContext(); + } + + public void loadSslContext() { TlsMode tlsMode = TlsSystemConfig.tlsMode; log.info("Server is running in TLS {} mode", tlsMode.getName()); diff --git a/remoting/src/test/java/org/apache/rocketmq/remoting/TlsTest.java b/remoting/src/test/java/org/apache/rocketmq/remoting/TlsTest.java index 5e516dd74bda5488afc01197eac0816a055983e7..13bb17282e30ec689f495fa4920eb555bae16fbc 100644 --- a/remoting/src/test/java/org/apache/rocketmq/remoting/TlsTest.java +++ b/remoting/src/test/java/org/apache/rocketmq/remoting/TlsTest.java @@ -25,6 +25,7 @@ import java.io.PrintWriter; import org.apache.rocketmq.remoting.common.TlsMode; import org.apache.rocketmq.remoting.exception.RemotingSendRequestException; import org.apache.rocketmq.remoting.netty.NettyClientConfig; +import org.apache.rocketmq.remoting.netty.NettyRemotingServer; import org.apache.rocketmq.remoting.netty.TlsHelper; import org.apache.rocketmq.remoting.protocol.LanguageCode; import org.apache.rocketmq.remoting.protocol.RemotingCommand; @@ -134,6 +135,9 @@ public class TlsTest { clientConfig.setUseTLS(false); } else if ("serverRejectsSSLClient".equals(name.getMethodName())) { tlsMode = TlsMode.DISABLED; + } else if ("reloadSslContextForServer".equals(name.getMethodName())) { + tlsClientAuthServer = false; + tlsServerNeedClientAuth = "none"; } remotingServer = RemotingServerTest.createRemotingServer(); @@ -156,6 +160,26 @@ public class TlsTest { requestThenAssertResponse(); } + @Test + public void reloadSslContextForServer() throws Exception { + requestThenAssertResponse(); + + //Use new cert and private key + tlsClientKeyPath = getCertsPath("badClient.key"); + tlsClientCertPath = getCertsPath("badClient.pem"); + + ((NettyRemotingServer) remotingServer).loadSslContext(); + + //Request Again + requestThenAssertResponse(); + + //Start another client + NettyClientConfig clientConfig = new NettyClientConfig(); + clientConfig.setUseTLS(true); + RemotingClient remotingClient = RemotingServerTest.createRemotingClient(clientConfig); + requestThenAssertResponse(remotingClient); + } + @Test public void serverNotNeedClientAuth() throws Exception { requestThenAssertResponse(); @@ -289,6 +313,10 @@ public class TlsTest { } private void requestThenAssertResponse() throws Exception { + requestThenAssertResponse(remotingClient); + } + + private void requestThenAssertResponse(RemotingClient remotingClient) throws Exception { RemotingCommand response = remotingClient.invokeSync("localhost:8888", createRequest(), 1000 * 3); assertTrue(response != null); assertThat(response.getLanguage()).isEqualTo(LanguageCode.JAVA); diff --git a/srvutil/src/main/java/org/apache/rocketmq/srvutil/FileWatchService.java b/srvutil/src/main/java/org/apache/rocketmq/srvutil/FileWatchService.java new file mode 100644 index 0000000000000000000000000000000000000000..b32e240b5109aeb0201ecd41e1517c8a76b2aa46 --- /dev/null +++ b/srvutil/src/main/java/org/apache/rocketmq/srvutil/FileWatchService.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.srvutil; + +import com.google.common.hash.HashCode; +import com.google.common.hash.Hashing; +import com.google.common.io.Files; +import java.io.File; +import java.io.IOException; +import org.apache.rocketmq.common.ServiceThread; +import org.apache.rocketmq.common.constant.LoggerName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class FileWatchService extends ServiceThread { + private static final Logger log = LoggerFactory.getLogger(LoggerName.COMMON_LOGGER_NAME); + + private String [] watchFiles; + private boolean [] isFileChangedFlag; + private HashCode [] fileCurrentHash; + private Listener listener; + private static final int WATCH_INTERVAL = 100; + + + public FileWatchService(final String [] watchFiles, + final Listener listener) throws IOException { + this.watchFiles = watchFiles; + this.listener = listener; + this.isFileChangedFlag = new boolean[watchFiles.length]; + this.fileCurrentHash = new HashCode[watchFiles.length]; + + for (int i = 0; i < watchFiles.length; i++) { + isFileChangedFlag[i] = false; + fileCurrentHash[i] = Files.hash(new File(watchFiles[i]), Hashing.md5()); + } + } + + @Override + public String getServiceName() { + return "FileWatchService"; + } + + @Override + public void run() { + log.info(this.getServiceName() + " service started"); + + while (!this.isStopped()) { + try { + this.waitForRunning(WATCH_INTERVAL); + + boolean allFileChanged = true; + for (int i = 0; i < watchFiles.length; i++) { + HashCode newHash = Files.hash(new File(watchFiles[i]), Hashing.md5()); + if (!newHash.equals(fileCurrentHash[i])) { + isFileChangedFlag[i] = true; + fileCurrentHash[i] = newHash; + } + allFileChanged = allFileChanged && isFileChangedFlag[i]; + } + + if (allFileChanged) { + listener.onChanged(); + for (int i = 0; i < isFileChangedFlag.length; i++) { + isFileChangedFlag[i] = false; + } + } + } catch (Exception e) { + log.warn(this.getServiceName() + " service has exception. ", e); + } + } + log.info(this.getServiceName() + " service end"); + } + + public interface Listener { + /** + * Will be called when the target files are changed + */ + void onChanged(); + } +} diff --git a/srvutil/src/main/test/org/apache/rocketmq/srvutil/FileWatchServiceTest.java b/srvutil/src/main/test/org/apache/rocketmq/srvutil/FileWatchServiceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..0d651480ada2c06cc8b9b3ef454731bfb7d31e50 --- /dev/null +++ b/srvutil/src/main/test/org/apache/rocketmq/srvutil/FileWatchServiceTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.srvutil; + +import java.io.File; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +@RunWith(MockitoJUnitRunner.class) +public class FileWatchServiceTest { + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Test + public void watchSingleFile() throws IOException, InterruptedException { + File file = tempFolder.newFile(); + final Semaphore waitSemaphore = new Semaphore(0); + FileWatchService fileWatchService = new FileWatchService(new String[] {file.getAbsolutePath()}, new FileWatchService.Listener() { + @Override + public void onChanged() { + waitSemaphore.release(); + } + }); + fileWatchService.start(); + modifyFile(file); + boolean result = waitSemaphore.tryAcquire(1, 100, TimeUnit.MILLISECONDS); + assertThat(result).isTrue(); + + } + + @Test + public void watchTwoFiles_ModifyOne() throws IOException, InterruptedException { + File fileA = tempFolder.newFile(); + File fileB = tempFolder.newFile(); + final Semaphore waitSemaphore = new Semaphore(0); + FileWatchService fileWatchService = new FileWatchService( + new String[] {fileA.getAbsolutePath(), fileB.getAbsolutePath()}, + new FileWatchService.Listener() { + @Override + public void onChanged() { + waitSemaphore.release(); + } + }); + fileWatchService.start(); + modifyFile(fileA); + boolean result = waitSemaphore.tryAcquire(1, 100, TimeUnit.MILLISECONDS); + assertThat(result).isFalse(); + } + + @Test + public void watchTwoFiles() throws IOException, InterruptedException { + File fileA = tempFolder.newFile(); + File fileB = tempFolder.newFile(); + final Semaphore waitSemaphore = new Semaphore(0); + FileWatchService fileWatchService = new FileWatchService( + new String[] {fileA.getAbsolutePath(), fileB.getAbsolutePath()}, + new FileWatchService.Listener() { + @Override + public void onChanged() { + waitSemaphore.release(); + } + }); + fileWatchService.start(); + modifyFile(fileA); + modifyFile(fileB); + boolean result = waitSemaphore.tryAcquire(1, 100, TimeUnit.MILLISECONDS); + assertThat(result).isTrue(); + } + + private static void modifyFile(File file) { + try { + PrintWriter out = new PrintWriter(file); + out.println(System.currentTimeMillis()); + out.flush(); + out.close(); + } catch (IOException ignore) { + } + } +} \ No newline at end of file