diff --git a/broker/src/main/java/org/apache/rocketmq/broker/BrokerController.java b/broker/src/main/java/org/apache/rocketmq/broker/BrokerController.java index 0a6f0b45e9de20f282880b854ad209ec97f31cf3..8891bd322e3f05128ba1d983be1fa05348fd366c 100644 --- a/broker/src/main/java/org/apache/rocketmq/broker/BrokerController.java +++ b/broker/src/main/java/org/apache/rocketmq/broker/BrokerController.java @@ -73,11 +73,14 @@ import org.apache.rocketmq.common.protocol.body.TopicConfigSerializeWrapper; import org.apache.rocketmq.common.stats.MomentStatsItem; import org.apache.rocketmq.remoting.RPCHook; import org.apache.rocketmq.remoting.RemotingServer; +import org.apache.rocketmq.remoting.common.TlsMode; import org.apache.rocketmq.remoting.netty.NettyClientConfig; import org.apache.rocketmq.remoting.netty.NettyRemotingServer; import org.apache.rocketmq.remoting.netty.NettyRequestProcessor; import org.apache.rocketmq.remoting.netty.NettyServerConfig; import org.apache.rocketmq.remoting.netty.RequestTask; +import org.apache.rocketmq.remoting.netty.TlsSystemConfig; +import org.apache.rocketmq.srvutil.FileWatchService; import org.apache.rocketmq.store.DefaultMessageStore; import org.apache.rocketmq.store.MessageArrivingListener; import org.apache.rocketmq.store.MessageStore; @@ -136,6 +139,7 @@ public class BrokerController { private InetSocketAddress storeHost; private BrokerFastFailure brokerFastFailure; private Configuration configuration; + private FileWatchService fileWatchService; public BrokerController( final BrokerConfig brokerConfig, @@ -387,6 +391,45 @@ public class BrokerController { } }, 1000 * 10, 1000 * 60, TimeUnit.MILLISECONDS); } + + if (TlsSystemConfig.tlsMode != TlsMode.DISABLED) { + // Register a listener to reload SslContext + try { + fileWatchService = new FileWatchService( + new String[] { + TlsSystemConfig.tlsServerCertPath, + TlsSystemConfig.tlsServerKeyPath, + TlsSystemConfig.tlsServerTrustCertPath + }, + new FileWatchService.Listener() { + boolean certChanged, keyChanged = false; + @Override + public void onChanged(String path) { + if (path.equals(TlsSystemConfig.tlsServerTrustCertPath)) { + log.info("The trust certificate changed, reload the ssl context"); + reloadServerSslContext(); + } + if (path.equals(TlsSystemConfig.tlsServerCertPath)) { + certChanged = true; + } + if (path.equals(TlsSystemConfig.tlsServerKeyPath)) { + keyChanged = true; + } + if (certChanged && keyChanged) { + log.info("The certificate and private key changed, reload the ssl context"); + certChanged = keyChanged = false; + reloadServerSslContext(); + } + } + private void reloadServerSslContext() { + ((NettyRemotingServer) remotingServer).loadSslContext(); + ((NettyRemotingServer) fastRemotingServer).loadSslContext(); + } + }); + } catch (Exception e) { + log.warn("FileWatchService created error, can't load the certificate dynamically"); + } + } } return result; @@ -594,6 +637,10 @@ public class BrokerController { this.fastRemotingServer.shutdown(); } + if (this.fileWatchService != null) { + this.fileWatchService.shutdown(); + } + if (this.messageStore != null) { this.messageStore.shutdown(); } @@ -662,6 +709,10 @@ public class BrokerController { this.fastRemotingServer.start(); } + if (this.fileWatchService != null) { + this.fileWatchService.start(); + } + if (this.brokerOuterAPI != null) { this.brokerOuterAPI.start(); } diff --git a/namesrv/src/main/java/org/apache/rocketmq/namesrv/NamesrvController.java b/namesrv/src/main/java/org/apache/rocketmq/namesrv/NamesrvController.java index 51b20b416dd5c4b10ff7b6014791295eb533b21e..2ed599c11c49549ac556027e2b7896848af9f362 100644 --- a/namesrv/src/main/java/org/apache/rocketmq/namesrv/NamesrvController.java +++ b/namesrv/src/main/java/org/apache/rocketmq/namesrv/NamesrvController.java @@ -30,8 +30,11 @@ import org.apache.rocketmq.namesrv.processor.DefaultRequestProcessor; import org.apache.rocketmq.namesrv.routeinfo.BrokerHousekeepingService; import org.apache.rocketmq.namesrv.routeinfo.RouteInfoManager; import org.apache.rocketmq.remoting.RemotingServer; +import org.apache.rocketmq.remoting.common.TlsMode; import org.apache.rocketmq.remoting.netty.NettyRemotingServer; import org.apache.rocketmq.remoting.netty.NettyServerConfig; +import org.apache.rocketmq.remoting.netty.TlsSystemConfig; +import org.apache.rocketmq.srvutil.FileWatchService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,6 +57,7 @@ public class NamesrvController { private ExecutorService remotingExecutor; private Configuration configuration; + private FileWatchService fileWatchService; public NamesrvController(NamesrvConfig namesrvConfig, NettyServerConfig nettyServerConfig) { this.namesrvConfig = namesrvConfig; @@ -95,6 +99,44 @@ public class NamesrvController { } }, 1, 10, TimeUnit.MINUTES); + if (TlsSystemConfig.tlsMode != TlsMode.DISABLED) { + // Register a listener to reload SslContext + try { + fileWatchService = new FileWatchService( + new String[] { + TlsSystemConfig.tlsServerCertPath, + TlsSystemConfig.tlsServerKeyPath, + TlsSystemConfig.tlsServerTrustCertPath + }, + new FileWatchService.Listener() { + boolean certChanged, keyChanged = false; + @Override + public void onChanged(String path) { + if (path.equals(TlsSystemConfig.tlsServerTrustCertPath)) { + log.info("The trust certificate changed, reload the ssl context"); + reloadServerSslContext(); + } + if (path.equals(TlsSystemConfig.tlsServerCertPath)) { + certChanged = true; + } + if (path.equals(TlsSystemConfig.tlsServerKeyPath)) { + keyChanged = true; + } + if (certChanged && keyChanged) { + log.info("The certificate and private key changed, reload the ssl context"); + certChanged = keyChanged = false; + reloadServerSslContext(); + } + } + private void reloadServerSslContext() { + ((NettyRemotingServer) remotingServer).loadSslContext(); + } + }); + } catch (Exception e) { + log.warn("FileWatchService created error, can't load the certificate dynamically"); + } + } + return true; } @@ -111,12 +153,20 @@ public class NamesrvController { public void start() throws Exception { this.remotingServer.start(); + + if (this.fileWatchService != null) { + this.fileWatchService.start(); + } } public void shutdown() { this.remotingServer.shutdown(); this.remotingExecutor.shutdown(); this.scheduledExecutorService.shutdown(); + + if (this.fileWatchService != null) { + this.fileWatchService.shutdown(); + } } public NamesrvConfig getNamesrvConfig() { diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java index 76752529af24f021c6860fb90866f629d4a8a75b..557ad5602b0bf961af7c40dbba6613a9ecba34ff 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java @@ -93,7 +93,7 @@ public abstract class NettyRemotingAbstract { /** * SSL context via which to create {@link SslHandler}. */ - protected SslContext sslContext; + protected volatile SslContext sslContext; /** * Constructor, specifying capacity of one-way and asynchronous semaphores. diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java index cd6ed4704f48a36413f425a4f54b2561ba8ef01d..c8709a501498d50bb408018a3f40848d169e2deb 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java @@ -139,6 +139,10 @@ public class NettyRemotingServer extends NettyRemotingAbstract implements Remoti }); } + loadSslContext(); + } + + public void loadSslContext() { TlsMode tlsMode = TlsSystemConfig.tlsMode; log.info("Server is running in TLS {} mode", tlsMode.getName()); diff --git a/remoting/src/test/java/org/apache/rocketmq/remoting/TlsTest.java b/remoting/src/test/java/org/apache/rocketmq/remoting/TlsTest.java index 5e516dd74bda5488afc01197eac0816a055983e7..13bb17282e30ec689f495fa4920eb555bae16fbc 100644 --- a/remoting/src/test/java/org/apache/rocketmq/remoting/TlsTest.java +++ b/remoting/src/test/java/org/apache/rocketmq/remoting/TlsTest.java @@ -25,6 +25,7 @@ import java.io.PrintWriter; import org.apache.rocketmq.remoting.common.TlsMode; import org.apache.rocketmq.remoting.exception.RemotingSendRequestException; import org.apache.rocketmq.remoting.netty.NettyClientConfig; +import org.apache.rocketmq.remoting.netty.NettyRemotingServer; import org.apache.rocketmq.remoting.netty.TlsHelper; import org.apache.rocketmq.remoting.protocol.LanguageCode; import org.apache.rocketmq.remoting.protocol.RemotingCommand; @@ -134,6 +135,9 @@ public class TlsTest { clientConfig.setUseTLS(false); } else if ("serverRejectsSSLClient".equals(name.getMethodName())) { tlsMode = TlsMode.DISABLED; + } else if ("reloadSslContextForServer".equals(name.getMethodName())) { + tlsClientAuthServer = false; + tlsServerNeedClientAuth = "none"; } remotingServer = RemotingServerTest.createRemotingServer(); @@ -156,6 +160,26 @@ public class TlsTest { requestThenAssertResponse(); } + @Test + public void reloadSslContextForServer() throws Exception { + requestThenAssertResponse(); + + //Use new cert and private key + tlsClientKeyPath = getCertsPath("badClient.key"); + tlsClientCertPath = getCertsPath("badClient.pem"); + + ((NettyRemotingServer) remotingServer).loadSslContext(); + + //Request Again + requestThenAssertResponse(); + + //Start another client + NettyClientConfig clientConfig = new NettyClientConfig(); + clientConfig.setUseTLS(true); + RemotingClient remotingClient = RemotingServerTest.createRemotingClient(clientConfig); + requestThenAssertResponse(remotingClient); + } + @Test public void serverNotNeedClientAuth() throws Exception { requestThenAssertResponse(); @@ -289,6 +313,10 @@ public class TlsTest { } private void requestThenAssertResponse() throws Exception { + requestThenAssertResponse(remotingClient); + } + + private void requestThenAssertResponse(RemotingClient remotingClient) throws Exception { RemotingCommand response = remotingClient.invokeSync("localhost:8888", createRequest(), 1000 * 3); assertTrue(response != null); assertThat(response.getLanguage()).isEqualTo(LanguageCode.JAVA); diff --git a/srvutil/src/main/java/org/apache/rocketmq/srvutil/FileWatchService.java b/srvutil/src/main/java/org/apache/rocketmq/srvutil/FileWatchService.java new file mode 100644 index 0000000000000000000000000000000000000000..bc68d6a3ca6b3ab85cede9e951b2e4b881f19ace --- /dev/null +++ b/srvutil/src/main/java/org/apache/rocketmq/srvutil/FileWatchService.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.srvutil; + +import com.google.common.base.Strings; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.ArrayList; +import java.util.List; +import org.apache.rocketmq.common.ServiceThread; +import org.apache.rocketmq.common.UtilAll; +import org.apache.rocketmq.common.constant.LoggerName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class FileWatchService extends ServiceThread { + private static final Logger log = LoggerFactory.getLogger(LoggerName.COMMON_LOGGER_NAME); + + private final List watchFiles; + private final List fileCurrentHash; + private final Listener listener; + private static final int WATCH_INTERVAL = 500; + private MessageDigest md = MessageDigest.getInstance("MD5"); + + public FileWatchService(final String[] watchFiles, + final Listener listener) throws Exception { + this.listener = listener; + this.watchFiles = new ArrayList<>(); + this.fileCurrentHash = new ArrayList<>(); + + for (int i = 0; i < watchFiles.length; i++) { + if (!Strings.isNullOrEmpty(watchFiles[i]) && new File(watchFiles[i]).exists()) { + this.watchFiles.add(watchFiles[i]); + this.fileCurrentHash.add(hash(watchFiles[i])); + } + } + } + + @Override + public String getServiceName() { + return "FileWatchService"; + } + + @Override + public void run() { + log.info(this.getServiceName() + " service started"); + + while (!this.isStopped()) { + try { + this.waitForRunning(WATCH_INTERVAL); + + for (int i = 0; i < watchFiles.size(); i++) { + String newHash; + try { + newHash = hash(watchFiles.get(i)); + } catch (Exception ignored) { + log.warn(this.getServiceName() + " service has exception when calculate the file hash. ", ignored); + continue; + } + if (!newHash.equals(fileCurrentHash.get(i))) { + fileCurrentHash.set(i, newHash); + listener.onChanged(watchFiles.get(i)); + } + } + } catch (Exception e) { + log.warn(this.getServiceName() + " service has exception. ", e); + } + } + log.info(this.getServiceName() + " service end"); + } + + private String hash(String filePath) throws IOException, NoSuchAlgorithmException { + Path path = Paths.get(filePath); + md.update(Files.readAllBytes(path)); + byte[] hash = md.digest(); + return UtilAll.bytes2string(hash); + } + + public interface Listener { + /** + * Will be called when the target files are changed + * @param path the changed file path + */ + void onChanged(String path); + } +} diff --git a/srvutil/src/main/test/org/apache/rocketmq/srvutil/FileWatchServiceTest.java b/srvutil/src/main/test/org/apache/rocketmq/srvutil/FileWatchServiceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..791abcf0c952d9431174f4a37ec2d48015f35128 --- /dev/null +++ b/srvutil/src/main/test/org/apache/rocketmq/srvutil/FileWatchServiceTest.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.srvutil; + +import java.io.File; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +@RunWith(MockitoJUnitRunner.class) +public class FileWatchServiceTest { + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Test + public void watchSingleFile() throws Exception { + final File file = tempFolder.newFile(); + final Semaphore waitSemaphore = new Semaphore(0); + FileWatchService fileWatchService = new FileWatchService(new String[] {file.getAbsolutePath()}, new FileWatchService.Listener() { + @Override + public void onChanged(String path) { + assertThat(file.getAbsolutePath()).isEqualTo(path); + waitSemaphore.release(); + } + }); + fileWatchService.start(); + modifyFile(file); + boolean result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS); + assertThat(result).isTrue(); + } + + @Test + public void watchSingleFile_FileDeleted() throws Exception { + File file = tempFolder.newFile(); + final Semaphore waitSemaphore = new Semaphore(0); + FileWatchService fileWatchService = new FileWatchService(new String[] {file.getAbsolutePath()}, + new FileWatchService.Listener() { + @Override + public void onChanged(String path) { + waitSemaphore.release(); + } + }); + fileWatchService.start(); + file.delete(); + boolean result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS); + assertThat(result).isFalse(); + file.createNewFile(); + modifyFile(file); + result = waitSemaphore.tryAcquire(1, 2000, TimeUnit.MILLISECONDS); + assertThat(result).isTrue(); + } + + @Test + public void watchTwoFile_FileDeleted() throws Exception { + File fileA = tempFolder.newFile(); + File fileB = tempFolder.newFile(); + final Semaphore waitSemaphore = new Semaphore(0); + FileWatchService fileWatchService = new FileWatchService( + new String[] {fileA.getAbsolutePath(), fileB.getAbsolutePath()}, + new FileWatchService.Listener() { + @Override + public void onChanged(String path) { + waitSemaphore.release(); + } + }); + fileWatchService.start(); + fileA.delete(); + boolean result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS); + assertThat(result).isFalse(); + modifyFile(fileB); + result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS); + assertThat(result).isTrue(); + fileA.createNewFile(); + modifyFile(fileA); + result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS); + assertThat(result).isTrue(); + } + + @Test + public void watchTwoFiles_ModifyOne() throws Exception { + final File fileA = tempFolder.newFile(); + File fileB = tempFolder.newFile(); + final Semaphore waitSemaphore = new Semaphore(0); + FileWatchService fileWatchService = new FileWatchService( + new String[] {fileA.getAbsolutePath(), fileB.getAbsolutePath()}, + new FileWatchService.Listener() { + @Override + public void onChanged(String path) { + assertThat(path).isEqualTo(fileA.getAbsolutePath()); + waitSemaphore.release(); + } + }); + fileWatchService.start(); + modifyFile(fileA); + boolean result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS); + assertThat(result).isTrue(); + } + + @Test + public void watchTwoFiles() throws Exception { + File fileA = tempFolder.newFile(); + File fileB = tempFolder.newFile(); + final Semaphore waitSemaphore = new Semaphore(0); + FileWatchService fileWatchService = new FileWatchService( + new String[] {fileA.getAbsolutePath(), fileB.getAbsolutePath()}, + new FileWatchService.Listener() { + @Override + public void onChanged(String path) { + waitSemaphore.release(); + } + }); + fileWatchService.start(); + modifyFile(fileA); + modifyFile(fileB); + boolean result = waitSemaphore.tryAcquire(2, 1000, TimeUnit.MILLISECONDS); + assertThat(result).isTrue(); + } + + private static void modifyFile(File file) { + try { + PrintWriter out = new PrintWriter(file); + out.println(System.nanoTime()); + out.flush(); + out.close(); + } catch (IOException ignore) { + } + } +} \ No newline at end of file