提交 c7fea667 编写于 作者: Y yukon 提交者: von gosling

[ROCKETMQ-335] Reload server certificate, private key and root ca when these are changed (#207)

上级 69043c0d
......@@ -73,11 +73,14 @@ import org.apache.rocketmq.common.protocol.body.TopicConfigSerializeWrapper;
import org.apache.rocketmq.common.stats.MomentStatsItem;
import org.apache.rocketmq.remoting.RPCHook;
import org.apache.rocketmq.remoting.RemotingServer;
import org.apache.rocketmq.remoting.common.TlsMode;
import org.apache.rocketmq.remoting.netty.NettyClientConfig;
import org.apache.rocketmq.remoting.netty.NettyRemotingServer;
import org.apache.rocketmq.remoting.netty.NettyRequestProcessor;
import org.apache.rocketmq.remoting.netty.NettyServerConfig;
import org.apache.rocketmq.remoting.netty.RequestTask;
import org.apache.rocketmq.remoting.netty.TlsSystemConfig;
import org.apache.rocketmq.srvutil.FileWatchService;
import org.apache.rocketmq.store.DefaultMessageStore;
import org.apache.rocketmq.store.MessageArrivingListener;
import org.apache.rocketmq.store.MessageStore;
......@@ -136,6 +139,7 @@ public class BrokerController {
private InetSocketAddress storeHost;
private BrokerFastFailure brokerFastFailure;
private Configuration configuration;
private FileWatchService fileWatchService;
public BrokerController(
final BrokerConfig brokerConfig,
......@@ -387,6 +391,45 @@ public class BrokerController {
}
}, 1000 * 10, 1000 * 60, TimeUnit.MILLISECONDS);
}
if (TlsSystemConfig.tlsMode != TlsMode.DISABLED) {
// Register a listener to reload SslContext
try {
fileWatchService = new FileWatchService(
new String[] {
TlsSystemConfig.tlsServerCertPath,
TlsSystemConfig.tlsServerKeyPath,
TlsSystemConfig.tlsServerTrustCertPath
},
new FileWatchService.Listener() {
boolean certChanged, keyChanged = false;
@Override
public void onChanged(String path) {
if (path.equals(TlsSystemConfig.tlsServerTrustCertPath)) {
log.info("The trust certificate changed, reload the ssl context");
reloadServerSslContext();
}
if (path.equals(TlsSystemConfig.tlsServerCertPath)) {
certChanged = true;
}
if (path.equals(TlsSystemConfig.tlsServerKeyPath)) {
keyChanged = true;
}
if (certChanged && keyChanged) {
log.info("The certificate and private key changed, reload the ssl context");
certChanged = keyChanged = false;
reloadServerSslContext();
}
}
private void reloadServerSslContext() {
((NettyRemotingServer) remotingServer).loadSslContext();
((NettyRemotingServer) fastRemotingServer).loadSslContext();
}
});
} catch (Exception e) {
log.warn("FileWatchService created error, can't load the certificate dynamically");
}
}
}
return result;
......@@ -594,6 +637,10 @@ public class BrokerController {
this.fastRemotingServer.shutdown();
}
if (this.fileWatchService != null) {
this.fileWatchService.shutdown();
}
if (this.messageStore != null) {
this.messageStore.shutdown();
}
......@@ -662,6 +709,10 @@ public class BrokerController {
this.fastRemotingServer.start();
}
if (this.fileWatchService != null) {
this.fileWatchService.start();
}
if (this.brokerOuterAPI != null) {
this.brokerOuterAPI.start();
}
......
......@@ -30,8 +30,11 @@ import org.apache.rocketmq.namesrv.processor.DefaultRequestProcessor;
import org.apache.rocketmq.namesrv.routeinfo.BrokerHousekeepingService;
import org.apache.rocketmq.namesrv.routeinfo.RouteInfoManager;
import org.apache.rocketmq.remoting.RemotingServer;
import org.apache.rocketmq.remoting.common.TlsMode;
import org.apache.rocketmq.remoting.netty.NettyRemotingServer;
import org.apache.rocketmq.remoting.netty.NettyServerConfig;
import org.apache.rocketmq.remoting.netty.TlsSystemConfig;
import org.apache.rocketmq.srvutil.FileWatchService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -54,6 +57,7 @@ public class NamesrvController {
private ExecutorService remotingExecutor;
private Configuration configuration;
private FileWatchService fileWatchService;
public NamesrvController(NamesrvConfig namesrvConfig, NettyServerConfig nettyServerConfig) {
this.namesrvConfig = namesrvConfig;
......@@ -95,6 +99,44 @@ public class NamesrvController {
}
}, 1, 10, TimeUnit.MINUTES);
if (TlsSystemConfig.tlsMode != TlsMode.DISABLED) {
// Register a listener to reload SslContext
try {
fileWatchService = new FileWatchService(
new String[] {
TlsSystemConfig.tlsServerCertPath,
TlsSystemConfig.tlsServerKeyPath,
TlsSystemConfig.tlsServerTrustCertPath
},
new FileWatchService.Listener() {
boolean certChanged, keyChanged = false;
@Override
public void onChanged(String path) {
if (path.equals(TlsSystemConfig.tlsServerTrustCertPath)) {
log.info("The trust certificate changed, reload the ssl context");
reloadServerSslContext();
}
if (path.equals(TlsSystemConfig.tlsServerCertPath)) {
certChanged = true;
}
if (path.equals(TlsSystemConfig.tlsServerKeyPath)) {
keyChanged = true;
}
if (certChanged && keyChanged) {
log.info("The certificate and private key changed, reload the ssl context");
certChanged = keyChanged = false;
reloadServerSslContext();
}
}
private void reloadServerSslContext() {
((NettyRemotingServer) remotingServer).loadSslContext();
}
});
} catch (Exception e) {
log.warn("FileWatchService created error, can't load the certificate dynamically");
}
}
return true;
}
......@@ -111,12 +153,20 @@ public class NamesrvController {
public void start() throws Exception {
this.remotingServer.start();
if (this.fileWatchService != null) {
this.fileWatchService.start();
}
}
public void shutdown() {
this.remotingServer.shutdown();
this.remotingExecutor.shutdown();
this.scheduledExecutorService.shutdown();
if (this.fileWatchService != null) {
this.fileWatchService.shutdown();
}
}
public NamesrvConfig getNamesrvConfig() {
......
......@@ -93,7 +93,7 @@ public abstract class NettyRemotingAbstract {
/**
* SSL context via which to create {@link SslHandler}.
*/
protected SslContext sslContext;
protected volatile SslContext sslContext;
/**
* Constructor, specifying capacity of one-way and asynchronous semaphores.
......
......@@ -139,6 +139,10 @@ public class NettyRemotingServer extends NettyRemotingAbstract implements Remoti
});
}
loadSslContext();
}
public void loadSslContext() {
TlsMode tlsMode = TlsSystemConfig.tlsMode;
log.info("Server is running in TLS {} mode", tlsMode.getName());
......
......@@ -25,6 +25,7 @@ import java.io.PrintWriter;
import org.apache.rocketmq.remoting.common.TlsMode;
import org.apache.rocketmq.remoting.exception.RemotingSendRequestException;
import org.apache.rocketmq.remoting.netty.NettyClientConfig;
import org.apache.rocketmq.remoting.netty.NettyRemotingServer;
import org.apache.rocketmq.remoting.netty.TlsHelper;
import org.apache.rocketmq.remoting.protocol.LanguageCode;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;
......@@ -134,6 +135,9 @@ public class TlsTest {
clientConfig.setUseTLS(false);
} else if ("serverRejectsSSLClient".equals(name.getMethodName())) {
tlsMode = TlsMode.DISABLED;
} else if ("reloadSslContextForServer".equals(name.getMethodName())) {
tlsClientAuthServer = false;
tlsServerNeedClientAuth = "none";
}
remotingServer = RemotingServerTest.createRemotingServer();
......@@ -156,6 +160,26 @@ public class TlsTest {
requestThenAssertResponse();
}
@Test
public void reloadSslContextForServer() throws Exception {
requestThenAssertResponse();
//Use new cert and private key
tlsClientKeyPath = getCertsPath("badClient.key");
tlsClientCertPath = getCertsPath("badClient.pem");
((NettyRemotingServer) remotingServer).loadSslContext();
//Request Again
requestThenAssertResponse();
//Start another client
NettyClientConfig clientConfig = new NettyClientConfig();
clientConfig.setUseTLS(true);
RemotingClient remotingClient = RemotingServerTest.createRemotingClient(clientConfig);
requestThenAssertResponse(remotingClient);
}
@Test
public void serverNotNeedClientAuth() throws Exception {
requestThenAssertResponse();
......@@ -289,6 +313,10 @@ public class TlsTest {
}
private void requestThenAssertResponse() throws Exception {
requestThenAssertResponse(remotingClient);
}
private void requestThenAssertResponse(RemotingClient remotingClient) throws Exception {
RemotingCommand response = remotingClient.invokeSync("localhost:8888", createRequest(), 1000 * 3);
assertTrue(response != null);
assertThat(response.getLanguage()).isEqualTo(LanguageCode.JAVA);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.srvutil;
import com.google.common.base.Strings;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import org.apache.rocketmq.common.ServiceThread;
import org.apache.rocketmq.common.UtilAll;
import org.apache.rocketmq.common.constant.LoggerName;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class FileWatchService extends ServiceThread {
private static final Logger log = LoggerFactory.getLogger(LoggerName.COMMON_LOGGER_NAME);
private final List<String> watchFiles;
private final List<String> fileCurrentHash;
private final Listener listener;
private static final int WATCH_INTERVAL = 500;
private MessageDigest md = MessageDigest.getInstance("MD5");
public FileWatchService(final String[] watchFiles,
final Listener listener) throws Exception {
this.listener = listener;
this.watchFiles = new ArrayList<>();
this.fileCurrentHash = new ArrayList<>();
for (int i = 0; i < watchFiles.length; i++) {
if (!Strings.isNullOrEmpty(watchFiles[i]) && new File(watchFiles[i]).exists()) {
this.watchFiles.add(watchFiles[i]);
this.fileCurrentHash.add(hash(watchFiles[i]));
}
}
}
@Override
public String getServiceName() {
return "FileWatchService";
}
@Override
public void run() {
log.info(this.getServiceName() + " service started");
while (!this.isStopped()) {
try {
this.waitForRunning(WATCH_INTERVAL);
for (int i = 0; i < watchFiles.size(); i++) {
String newHash;
try {
newHash = hash(watchFiles.get(i));
} catch (Exception ignored) {
log.warn(this.getServiceName() + " service has exception when calculate the file hash. ", ignored);
continue;
}
if (!newHash.equals(fileCurrentHash.get(i))) {
fileCurrentHash.set(i, newHash);
listener.onChanged(watchFiles.get(i));
}
}
} catch (Exception e) {
log.warn(this.getServiceName() + " service has exception. ", e);
}
}
log.info(this.getServiceName() + " service end");
}
private String hash(String filePath) throws IOException, NoSuchAlgorithmException {
Path path = Paths.get(filePath);
md.update(Files.readAllBytes(path));
byte[] hash = md.digest();
return UtilAll.bytes2string(hash);
}
public interface Listener {
/**
* Will be called when the target files are changed
* @param path the changed file path
*/
void onChanged(String path);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.srvutil;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;
import static org.assertj.core.api.Assertions.assertThat;
@RunWith(MockitoJUnitRunner.class)
public class FileWatchServiceTest {
@Rule
public TemporaryFolder tempFolder = new TemporaryFolder();
@Test
public void watchSingleFile() throws Exception {
final File file = tempFolder.newFile();
final Semaphore waitSemaphore = new Semaphore(0);
FileWatchService fileWatchService = new FileWatchService(new String[] {file.getAbsolutePath()}, new FileWatchService.Listener() {
@Override
public void onChanged(String path) {
assertThat(file.getAbsolutePath()).isEqualTo(path);
waitSemaphore.release();
}
});
fileWatchService.start();
modifyFile(file);
boolean result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS);
assertThat(result).isTrue();
}
@Test
public void watchSingleFile_FileDeleted() throws Exception {
File file = tempFolder.newFile();
final Semaphore waitSemaphore = new Semaphore(0);
FileWatchService fileWatchService = new FileWatchService(new String[] {file.getAbsolutePath()},
new FileWatchService.Listener() {
@Override
public void onChanged(String path) {
waitSemaphore.release();
}
});
fileWatchService.start();
file.delete();
boolean result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS);
assertThat(result).isFalse();
file.createNewFile();
modifyFile(file);
result = waitSemaphore.tryAcquire(1, 2000, TimeUnit.MILLISECONDS);
assertThat(result).isTrue();
}
@Test
public void watchTwoFile_FileDeleted() throws Exception {
File fileA = tempFolder.newFile();
File fileB = tempFolder.newFile();
final Semaphore waitSemaphore = new Semaphore(0);
FileWatchService fileWatchService = new FileWatchService(
new String[] {fileA.getAbsolutePath(), fileB.getAbsolutePath()},
new FileWatchService.Listener() {
@Override
public void onChanged(String path) {
waitSemaphore.release();
}
});
fileWatchService.start();
fileA.delete();
boolean result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS);
assertThat(result).isFalse();
modifyFile(fileB);
result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS);
assertThat(result).isTrue();
fileA.createNewFile();
modifyFile(fileA);
result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS);
assertThat(result).isTrue();
}
@Test
public void watchTwoFiles_ModifyOne() throws Exception {
final File fileA = tempFolder.newFile();
File fileB = tempFolder.newFile();
final Semaphore waitSemaphore = new Semaphore(0);
FileWatchService fileWatchService = new FileWatchService(
new String[] {fileA.getAbsolutePath(), fileB.getAbsolutePath()},
new FileWatchService.Listener() {
@Override
public void onChanged(String path) {
assertThat(path).isEqualTo(fileA.getAbsolutePath());
waitSemaphore.release();
}
});
fileWatchService.start();
modifyFile(fileA);
boolean result = waitSemaphore.tryAcquire(1, 1000, TimeUnit.MILLISECONDS);
assertThat(result).isTrue();
}
@Test
public void watchTwoFiles() throws Exception {
File fileA = tempFolder.newFile();
File fileB = tempFolder.newFile();
final Semaphore waitSemaphore = new Semaphore(0);
FileWatchService fileWatchService = new FileWatchService(
new String[] {fileA.getAbsolutePath(), fileB.getAbsolutePath()},
new FileWatchService.Listener() {
@Override
public void onChanged(String path) {
waitSemaphore.release();
}
});
fileWatchService.start();
modifyFile(fileA);
modifyFile(fileB);
boolean result = waitSemaphore.tryAcquire(2, 1000, TimeUnit.MILLISECONDS);
assertThat(result).isTrue();
}
private static void modifyFile(File file) {
try {
PrintWriter out = new PrintWriter(file);
out.println(System.nanoTime());
out.flush();
out.close();
} catch (IOException ignore) {
}
}
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册