[FLINK-7878] Hide GpuResource in ResourceSpec

上级 5b9ac950
......@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
import org.apache.flink.util.Preconditions;
import javax.annotation.Nonnull;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
......@@ -32,9 +33,9 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
/**
* Describe the different resource factors of the operator with UDF.
*
* The state backend provides the method to estimate memory usages based on state size in the resource.
* <p>The state backend provides the method to estimate memory usages based on state size in the resource.
*
* Resource provides {@link #merge(ResourceSpec)} method for chained operators when generating job graph.
* <p>Resource provides {@link #merge(ResourceSpec)} method for chained operators when generating job graph.
*
* <p>Resource provides {@link #lessThanOrEqual(ResourceSpec)} method to compare these fields in sequence:
* <ol>
......@@ -53,21 +54,21 @@ public class ResourceSpec implements Serializable {
public static final ResourceSpec DEFAULT = new ResourceSpec(0, 0, 0, 0, 0);
private static String GPU_NAME = "GPU";
private static final String GPU_NAME = "GPU";
/** How many cpu cores are needed, use double so we can specify cpu like 0.1 */
/** How many cpu cores are needed, use double so we can specify cpu like 0.1. */
private final double cpuCores;
/** How many java heap memory in mb are needed */
/** How many java heap memory in mb are needed. */
private final int heapMemoryInMB;
/** How many nio direct memory in mb are needed */
/** How many nio direct memory in mb are needed. */
private final int directMemoryInMB;
/** How many native memory in mb are needed */
/** How many native memory in mb are needed. */
private final int nativeMemoryInMB;
/** How many state size in mb are used */
/** How many state size in mb are used. */
private final int stateSizeInMB;
private final Map<String, Resource> extendedResources = new HashMap<>(1);
......@@ -239,8 +240,13 @@ public class ResourceSpec implements Serializable {
'}';
}
public static Builder newBuilder() { return new Builder(); }
public static Builder newBuilder() {
return new Builder();
}
/**
* Builder for the {@link ResourceSpec}.
*/
public static class Builder {
public double cpuCores;
......@@ -275,28 +281,40 @@ public class ResourceSpec implements Serializable {
return this;
}
public Builder setGPUResource(GPUResource gpuResource) {
this.gpuResource = gpuResource;
public Builder setGPUResource(double gpus) {
this.gpuResource = new GPUResource(gpus);
return this;
}
public ResourceSpec build() {
return new ResourceSpec(cpuCores, heapMemoryInMB, directMemoryInMB, nativeMemoryInMB, stateSizeInMB, gpuResource);
return new ResourceSpec(
cpuCores,
heapMemoryInMB,
directMemoryInMB,
nativeMemoryInMB,
stateSizeInMB,
gpuResource);
}
}
public static abstract class Resource implements Serializable {
/**
* Base class for additional resources one can specify.
*/
protected abstract static class Resource implements Serializable {
private static final long serialVersionUID = 1L;
/**
* Enum defining how resources are aggregated.
*/
public enum ResourceAggregateType {
/**
* Denotes keeping the sum of the values with same name when merging two resource specs for operator chaining
* Denotes keeping the sum of the values with same name when merging two resource specs for operator chaining.
*/
AGGREGATE_TYPE_SUM,
/**
* Denotes keeping the max of the values with same name when merging two resource specs for operator chaining
* Denotes keeping the max of the values with same name when merging two resource specs for operator chaining.
*/
AGGREGATE_TYPE_MAX
}
......@@ -305,7 +323,7 @@ public class ResourceSpec implements Serializable {
private final double value;
final private ResourceAggregateType type;
private final ResourceAggregateType type;
public Resource(String name, double value, ResourceAggregateType type) {
this.name = checkNotNull(name);
......@@ -348,14 +366,14 @@ public class ResourceSpec implements Serializable {
@Override
public int hashCode() {
int result = name != null ? name.hashCode() : 0;
int result = name.hashCode();
result = 31 * result + type.ordinal();
result = 31 * result + (int)value;
result = 31 * result + (int) value;
return result;
}
/**
* Create a resource of the same resource type
* Create a resource of the same resource type.
*
* @param value The value of the resource
* @param type The aggregate type of the resource
......@@ -369,6 +387,8 @@ public class ResourceSpec implements Serializable {
*/
public static class GPUResource extends Resource {
private static final long serialVersionUID = -2276080061777135142L;
public GPUResource(double value) {
this(value, ResourceAggregateType.AGGREGATE_TYPE_SUM);
}
......
......@@ -20,12 +20,12 @@ package org.apache.flink.api.common.operators;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.TestLogger;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* Tests for ResourceSpec class, including its all public api: isValid, lessThanOrEqual, equals, hashCode and merge.
......@@ -40,14 +40,14 @@ public class ResourceSpecTest extends TestLogger {
rs = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1)).
setGPUResource(1).
build();
assertTrue(rs.isValid());
rs = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(-1)).
setGPUResource(-1).
build();
assertFalse(rs.isValid());
}
......@@ -62,7 +62,7 @@ public class ResourceSpecTest extends TestLogger {
ResourceSpec rs3 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.1)).
setGPUResource(1.1).
build();
assertTrue(rs1.lessThanOrEqual(rs3));
assertFalse(rs3.lessThanOrEqual(rs1));
......@@ -70,7 +70,7 @@ public class ResourceSpecTest extends TestLogger {
ResourceSpec rs4 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
assertFalse(rs4.lessThanOrEqual(rs3));
assertTrue(rs3.lessThanOrEqual(rs4));
......@@ -86,19 +86,19 @@ public class ResourceSpecTest extends TestLogger {
ResourceSpec rs3 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
ResourceSpec rs4 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1)).
setGPUResource(1).
build();
assertFalse(rs3.equals(rs4));
ResourceSpec rs5 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
assertTrue(rs3.equals(rs5));
}
......@@ -112,28 +112,21 @@ public class ResourceSpecTest extends TestLogger {
ResourceSpec rs3 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
ResourceSpec rs4 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1)).
setGPUResource(1).
build();
assertFalse(rs3.hashCode() == rs4.hashCode());
ResourceSpec rs5 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2)).
setGPUResource(2.2).
build();
assertEquals(rs3.hashCode(), rs5.hashCode());
ResourceSpec rs6 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(2.2, ResourceSpec.Resource.ResourceAggregateType.AGGREGATE_TYPE_MAX)).
build();
assertFalse(rs6.hashCode() == rs5.hashCode());
}
@Test
......@@ -141,7 +134,7 @@ public class ResourceSpecTest extends TestLogger {
ResourceSpec rs1 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.1)).
setGPUResource(1.1).
build();
ResourceSpec rs2 = ResourceSpec.newBuilder().setCpuCores(1.0).setHeapMemoryInMB(100).build();
......@@ -150,26 +143,6 @@ public class ResourceSpecTest extends TestLogger {
ResourceSpec rs4 = rs1.merge(rs3);
assertEquals(2.2, rs4.getGPUResource(), 0.000001);
ResourceSpec rs5 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.1, ResourceSpec.Resource.ResourceAggregateType.AGGREGATE_TYPE_MAX)).
build();
try {
rs4.merge(rs5);
fail("Merge with different aggregate type should fail");
} catch (IllegalArgumentException ignored) {
}
ResourceSpec rs6 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.5, ResourceSpec.Resource.ResourceAggregateType.AGGREGATE_TYPE_MAX)).
build();
ResourceSpec rs7 = rs5.merge(rs6);
assertEquals(1.5, rs7.getGPUResource(), 0.000001);
}
@Test
......@@ -177,7 +150,7 @@ public class ResourceSpecTest extends TestLogger {
ResourceSpec rs1 = ResourceSpec.newBuilder().
setCpuCores(1.0).
setHeapMemoryInMB(100).
setGPUResource(new ResourceSpec.GPUResource(1.1)).
setGPUResource(1.1).
build();
byte[] buffer = InstantiationUtil.serializeObject(rs1);
ResourceSpec rs2 = InstantiationUtil.deserializeObject(buffer, ClassLoader.getSystemClassLoader());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册