未验证 提交 dac36482 编写于 作者: P Paul Lam 提交者: Till Rohrmann

[FLINK-11126][YARN][security] Filter out AMRMToken in the TaskManager credentials

This closes #7895.
上级 4f558e4f
......@@ -21,28 +21,50 @@ package org.apache.flink.yarn;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ResourceManagerOptions;
import org.apache.flink.core.testutils.CommonTestUtils;
import org.apache.flink.runtime.clusterframework.ContaineredTaskManagerParameters;
import org.apache.flink.util.TestLogger;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;
import org.apache.log4j.AppenderSkeleton;
import org.apache.log4j.Level;
import org.apache.log4j.spi.LoggingEvent;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
* Tests for various utilities.
*/
public class UtilsTest extends TestLogger {
private static final Logger LOG = LoggerFactory.getLogger(UtilsTest.class);
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Test
public void testUberjarLocator() {
File dir = YarnTestBase.findFile("..", new YarnTestBase.RootDirFilenameFilter());
......@@ -136,6 +158,72 @@ public class UtilsTest extends TestLogger {
Assert.assertEquals(0, res.size());
}
@Test
public void testCreateTaskExecutorCredentials() throws Exception {
File root = temporaryFolder.getRoot();
File home = new File(root, "home");
boolean created = home.mkdir();
assertTrue(created);
Configuration flinkConf = new Configuration();
YarnConfiguration yarnConf = new YarnConfiguration();
Map<String, String> env = new HashMap<>();
env.put(YarnConfigKeys.ENV_APP_ID, "foo");
env.put(YarnConfigKeys.ENV_CLIENT_HOME_DIR, home.getAbsolutePath());
env.put(YarnConfigKeys.ENV_CLIENT_SHIP_FILES, "");
env.put(YarnConfigKeys.ENV_FLINK_CLASSPATH, "");
env.put(YarnConfigKeys.ENV_HADOOP_USER_NAME, "foo");
env.put(YarnConfigKeys.FLINK_JAR_PATH, root.toURI().toString());
env = Collections.unmodifiableMap(env);
File credentialFile = temporaryFolder.newFile("container_tokens");
final Text amRmTokenKind = AMRMTokenIdentifier.KIND_NAME;
final Text hdfsDelegationTokenKind = new Text("HDFS_DELEGATION_TOKEN");
final Text service = new Text("test-service");
Credentials amCredentials = new Credentials();
amCredentials.addToken(amRmTokenKind, new Token<>(new byte[4], new byte[4], amRmTokenKind, service));
amCredentials.addToken(hdfsDelegationTokenKind, new Token<>(new byte[4], new byte[4],
hdfsDelegationTokenKind, service));
amCredentials.writeTokenStorageFile(new org.apache.hadoop.fs.Path(credentialFile.getAbsolutePath()), yarnConf);
ContaineredTaskManagerParameters tmParams = new ContaineredTaskManagerParameters(64,
64, 16, 1, new HashMap<>(1));
Configuration taskManagerConf = new Configuration();
String workingDirectory = root.getAbsolutePath();
Class<?> taskManagerMainClass = YarnTaskExecutorRunner.class;
ContainerLaunchContext ctx;
final Map<String, String> originalEnv = System.getenv();
try {
Map<String, String> systemEnv = new HashMap<>(originalEnv);
systemEnv.put("HADOOP_TOKEN_FILE_LOCATION", credentialFile.getAbsolutePath());
CommonTestUtils.setEnv(systemEnv);
ctx = Utils.createTaskExecutorContext(flinkConf, yarnConf, env, tmParams,
taskManagerConf, workingDirectory, taskManagerMainClass, LOG);
} finally {
CommonTestUtils.setEnv(originalEnv);
}
Credentials credentials = new Credentials();
try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(ctx.getTokens().array()))) {
credentials.readTokenStorageStream(dis);
}
Collection<Token<? extends TokenIdentifier>> tokens = credentials.getAllTokens();
boolean hasHdfsDelegationToken = false;
boolean hasAmRmToken = false;
for (Token<? extends TokenIdentifier> token : tokens) {
if (token.getKind().equals(amRmTokenKind)) {
hasAmRmToken = true;
} else if (token.getKind().equals(hdfsDelegationTokenKind)) {
hasHdfsDelegationToken = true;
}
}
assertTrue(hasHdfsDelegationToken);
assertFalse(hasAmRmToken);
}
//
// --------------- Tools to test if a certain string has been logged with Log4j. -------------
// See : http://stackoverflow.com/questions/3717402/how-to-test-w-junit-that-warning-was-logged-w-log4j
......
......@@ -26,7 +26,10 @@ import org.apache.flink.runtime.security.modules.HadoopModule;
import org.apache.flink.test.util.SecureTestEnvironment;
import org.apache.flink.test.util.TestingSecurityContext;
import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceScheduler;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.fifo.FifoScheduler;
import org.hamcrest.Matchers;
......@@ -39,6 +42,7 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
/**
......@@ -116,6 +120,21 @@ public class YARNSessionFIFOSecuredITCase extends YARNSessionFIFOITCase {
"The JobManager and the TaskManager should both run with Kerberos.",
jobManagerRunsWithKerberos && taskManagerRunsWithKerberos,
Matchers.is(true));
final List<String> amRMTokens = Lists.newArrayList(AMRMTokenIdentifier.KIND_NAME.toString());
final String jobmanagerContainerId = getContainerIdByLogName("jobmanager.log");
final String taskmanagerContainerId = getContainerIdByLogName("taskmanager.log");
final boolean jobmanagerWithAmRmToken = verifyTokenKindInContainerCredentials(amRMTokens, jobmanagerContainerId);
final boolean taskmanagerWithAmRmToken = verifyTokenKindInContainerCredentials(amRMTokens, taskmanagerContainerId);
Assert.assertThat(
"The JobManager should have AMRMToken.",
jobmanagerWithAmRmToken,
Matchers.is(true));
Assert.assertThat(
"The TaskManager should not have AMRMToken.",
taskmanagerWithAmRmToken,
Matchers.is(false));
}
/* For secure cluster testing, it is enough to run only one test and override below test methods
......
......@@ -33,6 +33,9 @@ import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.service.Service;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
import org.apache.hadoop.yarn.api.records.ContainerId;
......@@ -73,6 +76,7 @@ import java.nio.file.Paths;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
......@@ -464,9 +468,8 @@ public abstract class YarnTestBase extends TestLogger {
}
File f = new File(dir.getAbsolutePath() + "/" + name);
LOG.info("Searching in {}", f.getAbsolutePath());
try {
try (Scanner scanner = new Scanner(f)) {
Set<String> foundSet = new HashSet<>(mustHave.length);
Scanner scanner = new Scanner(f);
while (scanner.hasNextLine()) {
final String lineFromFile = scanner.nextLine();
for (String str : mustHave) {
......@@ -493,6 +496,53 @@ public abstract class YarnTestBase extends TestLogger {
}
}
public static boolean verifyTokenKindInContainerCredentials(final Collection<String> tokens, final String containerId)
throws IOException {
File cwd = new File("target/" + YARN_CONFIGURATION.get(TEST_CLUSTER_NAME_KEY));
if (!cwd.exists() || !cwd.isDirectory()) {
return false;
}
File containerTokens = findFile(cwd.getAbsolutePath(), new FilenameFilter() {
@Override
public boolean accept(File dir, String name) {
return name.equals(containerId + ".tokens");
}
});
if (containerTokens != null) {
LOG.info("Verifying tokens in {}", containerTokens.getAbsolutePath());
Credentials tmCredentials = Credentials.readTokenStorageFile(containerTokens, new Configuration());
Collection<Token<? extends TokenIdentifier>> userTokens = tmCredentials.getAllTokens();
Set<String> tokenKinds = new HashSet<>(4);
for (Token<? extends TokenIdentifier> token : userTokens) {
tokenKinds.add(token.getKind().toString());
}
return tokenKinds.containsAll(tokens);
} else {
LOG.warn("Unable to find credential file for container {}", containerId);
return false;
}
}
public static String getContainerIdByLogName(String logName) {
File cwd = new File("target/" + YARN_CONFIGURATION.get(TEST_CLUSTER_NAME_KEY));
File containerLog = findFile(cwd.getAbsolutePath(), new FilenameFilter() {
@Override
public boolean accept(File dir, String name) {
return name.equals(logName);
}
});
if (containerLog != null) {
return containerLog.getParentFile().getName();
} else {
throw new IllegalStateException("No container has log named " + logName);
}
}
public static void sleep(int time) {
try {
Thread.sleep(time);
......
......@@ -45,6 +45,7 @@ import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.Records;
import org.slf4j.Logger;
......@@ -567,7 +568,17 @@ public final class Utils {
new File(fileLocation),
HadoopUtils.getHadoopConfiguration(flinkConfig));
cred.writeTokenStorageToStream(dob);
// Filter out AMRMToken before setting the tokens to the TaskManager container context.
Credentials taskManagerCred = new Credentials();
Collection<Token<? extends TokenIdentifier>> userTokens = cred.getAllTokens();
for (Token<? extends TokenIdentifier> token : userTokens) {
if (!token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) {
final Text id = new Text(token.getIdentifier());
taskManagerCred.addToken(id, token);
}
}
taskManagerCred.writeTokenStorageToStream(dob);
ByteBuffer securityTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength());
ctx.setTokens(securityTokens);
} catch (Throwable t) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册