Commit 4abfa2f3 (unverified). Author: Husaimawx. Committer: GitHub

UDF function: MasterRepair (#6892)

Parent 68dedba8
......@@ -352,4 +352,65 @@ Output series:
|2020-01-01T00:00:28.000+08:00| 126.0|
|2020-01-01T00:00:30.000+08:00| 128.0|
+-----------------------------+-------------------------------------------------+
```
## MasterRepair
### Usage
This function is used to clean time series with master data.
**Name**: MasterRepair
**Input Series:** Supports multiple input series. The types are INT32 / INT64 / FLOAT / DOUBLE.
**Parameters:**
+ `omega`: The window size. It is a non-negative integer whose unit is milliseconds. By default, it is estimated from the distances between tuples at various time differences.
+ `eta`: The distance threshold. It is a positive number. By default, it is estimated from the distance distribution of tuples within windows.
+ `k`: The number of neighbors in the master data. It is a positive integer. By default, it is estimated from the distance between each tuple and its k-th nearest neighbor in the master data.
+ `output_column`: The index of the repaired column to output. Defaults to 1, which outputs the repair result of the first column.
**Output Series:** Outputs a single series of the same type as the input. This series is the selected input column after repair.
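The default for `omega` can be illustrated with a small sketch. Judging from `setParameters()` in the `MasterRepairUtil` class added by this commit, the estimate appears to be ten times the median gap between consecutive timestamps. The class and method names below (`OmegaEstimateSketch`, `estimateOmega`) are hypothetical and exist only for illustration, and the guard for fewer than two timestamps is an added assumption.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical helper, not part of IoTDB: mirrors the default-omega rule
// seen in MasterRepairUtil.setParameters() (median interval times 10).
final class OmegaEstimateSketch {

  static long estimateOmega(long[] timestamps) {
    // Assumption for this sketch: reject inputs that have no interval at all.
    if (timestamps.length < 2) {
      throw new IllegalArgumentException("need at least two timestamps");
    }
    List<Long> intervals = new ArrayList<>();
    for (int i = 1; i < timestamps.length; i++) {
      intervals.add(timestamps[i] - timestamps[i - 1]);
    }
    Collections.sort(intervals);
    // Element at size / 2 of the sorted list (the upper median for even sizes).
    long median = intervals.get(intervals.size() / 2);
    return median * 10;
  }

  public static void main(String[] args) {
    // Timestamps used by the integration test in this commit.
    long[] ts = {1, 2, 3, 4, 7, 8, 9, 10};
    System.out.println(estimateOmega(ts));
  }
}
```

With the test timestamps the sorted gaps are 1, 1, 1, 1, 1, 1, 3, so the median gap is 1 and the estimated window is 10 ms.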
### Examples
Input series:
```
+-----------------------------+------------+------------+------------+------------+------------+------------+
| Time|root.test.t1|root.test.t2|root.test.t3|root.test.m1|root.test.m2|root.test.m3|
+-----------------------------+------------+------------+------------+------------+------------+------------+
|2021-07-01T12:00:01.000+08:00| 1704| 1154.55| 0.195| 1704| 1154.55| 0.195|
|2021-07-01T12:00:02.000+08:00| 1702| 1152.30| 0.193| 1702| 1152.30| 0.193|
|2021-07-01T12:00:03.000+08:00| 1702| 1148.65| 0.192| 1702| 1148.65| 0.192|
|2021-07-01T12:00:04.000+08:00| 1701| 1145.20| 0.194| 1701| 1145.20| 0.194|
|2021-07-01T12:00:07.000+08:00| 1703| 1150.55| 0.195| 1703| 1150.55| 0.195|
|2021-07-01T12:00:08.000+08:00| 1694| 1151.55| 0.193| 1704| 1151.55| 0.193|
|2021-07-01T12:01:09.000+08:00| 1705| 1153.55| 0.194| 1705| 1153.55| 0.194|
|2021-07-01T12:01:10.000+08:00| 1706| 1152.30| 0.190| 1706| 1152.30| 0.190|
+-----------------------------+------------+------------+------------+------------+------------+------------+
```
SQL for query:
```sql
select MasterRepair(t1,t2,t3,m1,m2,m3) from root.test
```
Output series:
```
+-----------------------------+-------------------------------------------------------------------------------------------+
| Time|MasterRepair(root.test.t1,root.test.t2,root.test.t3,root.test.m1,root.test.m2,root.test.m3)|
+-----------------------------+-------------------------------------------------------------------------------------------+
|2021-07-01T12:00:01.000+08:00| 1704|
|2021-07-01T12:00:02.000+08:00| 1702|
|2021-07-01T12:00:03.000+08:00| 1702|
|2021-07-01T12:00:04.000+08:00| 1701|
|2021-07-01T12:00:07.000+08:00| 1703|
|2021-07-01T12:00:08.000+08:00| 1704|
|2021-07-01T12:01:09.000+08:00| 1705|
|2021-07-01T12:01:10.000+08:00| 1706|
+-----------------------------+-------------------------------------------------------------------------------------------+
```
\ No newline at end of file
......@@ -343,4 +343,66 @@ select valuerepair(s1,'method'='LsGreedy') from root.test.d2
|2020-01-01T00:00:28.000+08:00| 126.0|
|2020-01-01T00:00:30.000+08:00| 128.0|
+-----------------------------+-------------------------------------------------+
```
## MasterRepair
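### Overview
This function repairs time series data based on master data.
**Name:** MasterRepair
**Input Series:** Supports multiple input series. The types are INT32 / INT64 / FLOAT / DOUBLE.
**Parameters:**
- `omega`: The window size of the algorithm, a non-negative integer in milliseconds. By default, the algorithm estimates it from the distances between two tuples at different time differences.
- `eta`: The distance threshold of the algorithm, a positive number. By default, the algorithm estimates it from the distance distribution of the tuples in a window.
- `k`: The number of neighbors in the master data, a positive integer. By default, the algorithm estimates it from the tuple distances of the k nearest neighbors in the master data.
- `output_column`: The index of the column to output. By default, the repair result of the first column is output.
**Output Series:** Outputs a single series whose type is the same as that of the corresponding input column. The series is the input column after repair.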
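The default for `eta` can be sketched in the same spirit. Judging from `setParameters()` in `MasterRepairUtil`, the code collects the standardized distances between every tuple and the earlier tuples that lie at most `omega` before it, sorts them, and takes the element at position `size * 0.9973`. The names `EtaEstimateSketch` and `estimateEta` are hypothetical, the arrays stand in for the lists used in the commit, and the empty-input check is an added assumption.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical helper, not part of IoTDB: sketches the default-eta rule.
final class EtaEstimateSketch {

  // Euclidean distance after dividing each dimension by its standard deviation.
  static double distance(double[] a, double[] b, double[] std) {
    double sum = 0d;
    for (int i = 0; i < a.length; i++) {
      double d = (a[i] - b[i]) / std[i];
      sum += d * d;
    }
    return Math.sqrt(sum);
  }

  static double estimateEta(long[] times, double[][] tuples, double[] std, long omega) {
    List<Double> distances = new ArrayList<>();
    for (int i = 1; i < times.length; i++) {
      // Walk backwards while the earlier tuple is still inside the window.
      for (int l = i - 1; l >= 0 && times[i] <= times[l] + omega; l--) {
        distances.add(distance(tuples[i], tuples[l], std));
      }
    }
    // Assumption for this sketch: fail loudly when no pair falls in a window.
    if (distances.isEmpty()) {
      throw new IllegalStateException("no tuple pairs within omega");
    }
    Collections.sort(distances);
    return distances.get((int) (distances.size() * 0.9973));
  }

  public static void main(String[] args) {
    long[] times = {1, 2, 3};
    double[][] tuples = {{0.0}, {1.0}, {3.0}};
    System.out.println(estimateEta(times, tuples, new double[] {1.0}, 1));
  }
}
```

The factor 0.9973 matches the share of a normal distribution within three standard deviations, which is presumably why it was chosen; the commit itself does not state the reason.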
### Examples
Input series:
```
+-----------------------------+------------+------------+------------+------------+------------+------------+
| Time|root.test.t1|root.test.t2|root.test.t3|root.test.m1|root.test.m2|root.test.m3|
+-----------------------------+------------+------------+------------+------------+------------+------------+
|2021-07-01T12:00:01.000+08:00| 1704| 1154.55| 0.195| 1704| 1154.55| 0.195|
|2021-07-01T12:00:02.000+08:00| 1702| 1152.30| 0.193| 1702| 1152.30| 0.193|
|2021-07-01T12:00:03.000+08:00| 1702| 1148.65| 0.192| 1702| 1148.65| 0.192|
|2021-07-01T12:00:04.000+08:00| 1701| 1145.20| 0.194| 1701| 1145.20| 0.194|
|2021-07-01T12:00:07.000+08:00| 1703| 1150.55| 0.195| 1703| 1150.55| 0.195|
|2021-07-01T12:00:08.000+08:00| 1694| 1151.55| 0.193| 1704| 1151.55| 0.193|
|2021-07-01T12:01:09.000+08:00| 1705| 1153.55| 0.194| 1705| 1153.55| 0.194|
|2021-07-01T12:01:10.000+08:00| 1706| 1152.30| 0.190| 1706| 1152.30| 0.190|
+-----------------------------+------------+------------+------------+------------+------------+------------+
```
SQL for query:
```sql
select MasterRepair(t1,t2,t3,m1,m2,m3) from root.test
```
Output series:
```
+-----------------------------+-------------------------------------------------------------------------------------------+
| Time|MasterRepair(root.test.t1,root.test.t2,root.test.t3,root.test.m1,root.test.m2,root.test.m3)|
+-----------------------------+-------------------------------------------------------------------------------------------+
|2021-07-01T12:00:01.000+08:00| 1704|
|2021-07-01T12:00:02.000+08:00| 1702|
|2021-07-01T12:00:03.000+08:00| 1702|
|2021-07-01T12:00:04.000+08:00| 1701|
|2021-07-01T12:00:07.000+08:00| 1703|
|2021-07-01T12:00:08.000+08:00| 1704|
|2021-07-01T12:01:09.000+08:00| 1705|
|2021-07-01T12:01:10.000+08:00| 1706|
+-----------------------------+-------------------------------------------------------------------------------------------+
```
\ No newline at end of file
......@@ -74,7 +74,8 @@ public enum BuiltinTimeSeriesGeneratingFunctionEnum {
EQUAL_SIZE_BUCKET_AGG_SAMPLE("EQUAL_SIZE_BUCKET_AGG_SAMPLE"),
EQUAL_SIZE_BUCKET_M4_SAMPLE("EQUAL_SIZE_BUCKET_M4_SAMPLE"),
EQUAL_SIZE_BUCKET_OUTLIER_SAMPLE("EQUAL_SIZE_BUCKET_OUTLIER_SAMPLE"),
JEXL("JEXL");
JEXL("JEXL"),
MASTER_REPAIR("MASTER_REPAIR");
private final String functionName;
......
......@@ -1232,6 +1232,176 @@ public class IoTDBUDTFBuiltinFunctionIT {
}
}
@Test
public void testMasterRepair() {
// create time series with master data
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
statement.execute("SET STORAGE GROUP TO root.testMasterRepair");
statement.execute(
"CREATE TIMESERIES root.testMasterRepair.d1.s1 with datatype=FLOAT,encoding=PLAIN");
statement.execute(
"CREATE TIMESERIES root.testMasterRepair.d1.s2 with datatype=FLOAT,encoding=PLAIN");
statement.execute(
"CREATE TIMESERIES root.testMasterRepair.d1.s3 with datatype=FLOAT,encoding=PLAIN");
statement.execute(
"CREATE TIMESERIES root.testMasterRepair.d1.m1 with datatype=FLOAT,encoding=PLAIN");
statement.execute(
"CREATE TIMESERIES root.testMasterRepair.d1.m2 with datatype=FLOAT,encoding=PLAIN");
statement.execute(
"CREATE TIMESERIES root.testMasterRepair.d1.m3 with datatype=FLOAT,encoding=PLAIN");
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
String[] INSERT_SQL = {
"insert into root.testMasterRepair.d1(time, s1, s2, s3, m1, m2, m3) values (1,1704,1154.55,0.195,1704,1154.55,0.195)",
"insert into root.testMasterRepair.d1(time, s1, s2, s3, m1, m2, m3) values (2,1702,1152.30,0.193,1702,1152.30,0.193)",
"insert into root.testMasterRepair.d1(time, s1, s2, s3, m1, m2, m3) values (3,1702,1148.65,0.192,1702,1148.65,0.192)",
"insert into root.testMasterRepair.d1(time, s1, s2, s3, m1, m2, m3) values (4,1701,1145.20,0.194,1701,1145.20,0.194)",
"insert into root.testMasterRepair.d1(time, s1, s2, s3, m1, m2, m3) values (7,1703,1150.55,0.195,1703,1150.55,0.195)",
"insert into root.testMasterRepair.d1(time, s1, s2, s3, m1, m2, m3) values (8,1694,1151.55,0.193,1704,1151.55,0.193)",
"insert into root.testMasterRepair.d1(time, s1, s2, s3, m1, m2, m3) values (9,1705,1153.55,0.194,1705,1153.55,0.194)",
"insert into root.testMasterRepair.d1(time, s1, s2, s3, m1, m2, m3) values (10,1706,1152.30,0.190,1706,1152.30,0.190)",
};
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
for (String dataGenerationSql : INSERT_SQL) {
statement.execute(dataGenerationSql);
}
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
int[] timestamps = {1, 2, 3, 4, 7, 8, 9, 10};
// test 1
double[] r1 = {1704.0, 1702.0, 1702.0, 1701.0, 1703.0, 1702.0, 1705.0, 1706.0};
try (ResultSet resultSet =
statement.executeQuery(
"select master_repair(s1,s2,s3,m1,m2,m3) from root.testMasterRepair.d1")) {
int columnCount = resultSet.getMetaData().getColumnCount();
assertEquals(1 + 1, columnCount);
for (int i = 0; i < timestamps.length; i++) {
resultSet.next();
long expectedTimestamp = timestamps[i];
long actualTimestamp = Long.parseLong(resultSet.getString(1));
assertEquals(expectedTimestamp, actualTimestamp);
double expectedResult = r1[i];
double actualResult = resultSet.getDouble(2);
double delta = 0.001;
assertEquals(expectedResult, actualResult, delta);
}
}
// test 2
double[] r2 = {1154.55, 1152.30, 1148.65, 1145.20, 1150.55, 1152.30, 1153.55, 1152.30};
try (ResultSet resultSet =
statement.executeQuery(
"select master_repair(s1,s2,s3,m1,m2,m3,'output_column'='2') from root.testMasterRepair.d1")) {
int columnCount = resultSet.getMetaData().getColumnCount();
assertEquals(1 + 1, columnCount);
for (int i = 0; i < timestamps.length; i++) {
resultSet.next();
long expectedTimestamp = timestamps[i];
long actualTimestamp = Long.parseLong(resultSet.getString(1));
assertEquals(expectedTimestamp, actualTimestamp);
double expectedResult = r2[i];
double actualResult = resultSet.getDouble(2);
double delta = 0.001;
assertEquals(expectedResult, actualResult, delta);
}
}
// test 3
double[] r3 = {0.195, 0.193, 0.192, 0.194, 0.195, 0.193, 0.194, 0.190};
try (ResultSet resultSet =
statement.executeQuery(
"select master_repair(s1,s2,s3,m1,m2,m3,'output_column'='3') from root.testMasterRepair.d1")) {
int columnCount = resultSet.getMetaData().getColumnCount();
assertEquals(1 + 1, columnCount);
for (int i = 0; i < timestamps.length; i++) {
resultSet.next();
long expectedTimestamp = timestamps[i];
long actualTimestamp = Long.parseLong(resultSet.getString(1));
assertEquals(expectedTimestamp, actualTimestamp);
double expectedResult = r3[i];
double actualResult = resultSet.getDouble(2);
double delta = 0.001;
assertEquals(expectedResult, actualResult, delta);
}
}
// test 4
double[] r4 = {1704.0, 1702.0, 1702.0, 1701.0, 1703.0, 1704.0, 1705.0, 1706.0};
try (ResultSet resultSet =
statement.executeQuery(
"select master_repair(s1,s2,s3,m1,m2,m3,'omega'='2','eta'='3.0','k'='5') from root.testMasterRepair.d1")) {
int columnCount = resultSet.getMetaData().getColumnCount();
assertEquals(1 + 1, columnCount);
for (int i = 0; i < timestamps.length; i++) {
resultSet.next();
long expectedTimestamp = timestamps[i];
long actualTimestamp = Long.parseLong(resultSet.getString(1));
assertEquals(expectedTimestamp, actualTimestamp);
double expectedResult = r4[i];
double actualResult = resultSet.getDouble(2);
double delta = 0.001;
assertEquals(expectedResult, actualResult, delta);
}
}
// test 5
double[] r5 = {1154.55, 1152.30, 1148.65, 1145.20, 1150.55, 1151.55, 1153.55, 1152.30};
try (ResultSet resultSet =
statement.executeQuery(
"select master_repair(s1,s2,s3,m1,m2,m3,'omega'='2','eta'='3.0','k'='5','output_column'='2') from root.testMasterRepair.d1")) {
int columnCount = resultSet.getMetaData().getColumnCount();
assertEquals(1 + 1, columnCount);
for (int i = 0; i < timestamps.length; i++) {
resultSet.next();
long expectedTimestamp = timestamps[i];
long actualTimestamp = Long.parseLong(resultSet.getString(1));
assertEquals(expectedTimestamp, actualTimestamp);
double expectedResult = r5[i];
double actualResult = resultSet.getDouble(2);
double delta = 0.001;
assertEquals(expectedResult, actualResult, delta);
}
}
// test 6
double[] r6 = {0.195, 0.193, 0.192, 0.194, 0.195, 0.193, 0.194, 0.190};
try (ResultSet resultSet =
statement.executeQuery(
"select master_repair(s1,s2,s3,m1,m2,m3,'omega'='2','eta'='3.0','k'='5','output_column'='3') from root.testMasterRepair.d1")) {
int columnCount = resultSet.getMetaData().getColumnCount();
assertEquals(1 + 1, columnCount);
for (int i = 0; i < timestamps.length; i++) {
resultSet.next();
long expectedTimestamp = timestamps[i];
long actualTimestamp = Long.parseLong(resultSet.getString(1));
assertEquals(expectedTimestamp, actualTimestamp);
double expectedResult = r6[i];
double actualResult = resultSet.getDouble(2);
double delta = 0.001;
assertEquals(expectedResult, actualResult, delta);
}
}
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}
@Test
public void testDeDup() {
String[] createSQLs =
......
......@@ -88,7 +88,8 @@ public enum BuiltinTimeSeriesGeneratingFunction {
EQUAL_SIZE_BUCKET_M4_SAMPLE("EQUAL_SIZE_BUCKET_M4_SAMPLE", UDTFEqualSizeBucketM4Sample.class),
EQUAL_SIZE_BUCKET_OUTLIER_SAMPLE(
"EQUAL_SIZE_BUCKET_OUTLIER_SAMPLE", UDTFEqualSizeBucketOutlierSample.class),
JEXL("JEXL", UDTFJexl.class);
JEXL("JEXL", UDTFJexl.class),
MASTER_REPAIR("MASTER_REPAIR", UDTFMasterRepair.class);
private final String functionName;
private final Class<?> functionClass;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iotdb.commons.udf.builtin;
import org.apache.iotdb.commons.udf.utils.MasterRepairUtil;
import org.apache.iotdb.udf.api.UDTF;
import org.apache.iotdb.udf.api.access.Row;
import org.apache.iotdb.udf.api.collector.PointCollector;
import org.apache.iotdb.udf.api.customizer.config.UDTFConfigurations;
import org.apache.iotdb.udf.api.customizer.parameter.UDFParameterValidator;
import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
import org.apache.iotdb.udf.api.customizer.strategy.RowByRowAccessStrategy;
import org.apache.iotdb.udf.api.type.Type;
import java.util.ArrayList;
public class UDTFMasterRepair implements UDTF {
private MasterRepairUtil masterRepairUtil;
private int outputColumn;
@Override
public void validate(UDFParameterValidator validator) throws Exception {
for (int i = 0; i < validator.getParameters().getChildExpressionsSize(); i++) {
validator.validateInputSeriesDataType(i, Type.DOUBLE, Type.FLOAT, Type.INT32, Type.INT64);
}
if (validator.getParameters().hasAttribute("omega")) {
validator.validate(
omega -> (int) omega >= 0,
"Parameter omega should be non-negative.",
validator.getParameters().getInt("omega"));
}
if (validator.getParameters().hasAttribute("eta")) {
validator.validate(
eta -> (double) eta > 0,
"Parameter eta should be larger than 0.",
validator.getParameters().getDouble("eta"));
}
if (validator.getParameters().hasAttribute("k")) {
validator.validate(
k -> (int) k > 0,
"Parameter k should be a positive integer.",
validator.getParameters().getInt("k"));
}
if (validator.getParameters().hasAttribute("output_column")) {
validator.validate(
outputColumn -> (int) outputColumn > 0,
"Parameter output_column should be a positive integer.",
validator.getParameters().getInt("output_column"));
}
}
@Override
public void beforeStart(UDFParameters parameters, UDTFConfigurations configurations)
throws Exception {
configurations.setAccessStrategy(new RowByRowAccessStrategy());
configurations.setOutputDataType(Type.DOUBLE);
int columnCnt = parameters.getDataTypes().size() / 2;
long omega = parameters.getLongOrDefault("omega", -1);
double eta = parameters.getDoubleOrDefault("eta", Double.NaN);
int k = parameters.getIntOrDefault("k", -1);
masterRepairUtil = new MasterRepairUtil(columnCnt, omega, eta, k);
outputColumn = parameters.getIntOrDefault("output_column", 1);
}
@Override
public void transform(Row row, PointCollector collector) throws Exception {
if (!masterRepairUtil.isNullRow(row)) {
masterRepairUtil.addRow(row);
}
}
@Override
public void terminate(PointCollector collector) throws Exception {
masterRepairUtil.repair();
ArrayList<Long> times = masterRepairUtil.getTime();
ArrayList<Double> column = masterRepairUtil.getCleanResultColumn(this.outputColumn);
for (int i = 0; i < column.size(); i++) {
collector.putDouble(times.get(i), column.get(i));
}
}
}
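How `UDTFMasterRepair` interprets its input columns can be sketched as follows. In `beforeStart`, `columnCnt` is half the number of input series, and `MasterRepairUtil.addRow` reads the first `columnCnt` values as the tuple to repair and the remaining values as the master tuple; `output_column` is then a 1-based index into the repaired tuple. The class `ColumnLayoutSketch` and its methods are hypothetical names using plain arrays instead of the UDF `Row` API.

```java
import java.util.Arrays;

// Hypothetical helper, not part of IoTDB: shows the column layout only.
final class ColumnLayoutSketch {

  // First half of the row: the observed values that will be repaired.
  static double[] observedPart(double[] row) {
    int columnCnt = row.length / 2;
    return Arrays.copyOfRange(row, 0, columnCnt);
  }

  // Remaining values of the row: the master data tuple.
  static double[] masterPart(double[] row) {
    int columnCnt = row.length / 2;
    return Arrays.copyOfRange(row, columnCnt, row.length);
  }

  // output_column is 1-based, so column 1 maps to array index 0.
  static double select(double[] tuple, int outputColumn) {
    return tuple[outputColumn - 1];
  }

  public static void main(String[] args) {
    // Row at timestamp 8 from the integration test: s1, s2, s3, m1, m2, m3.
    double[] row = {1694, 1151.55, 0.193, 1704, 1151.55, 0.193};
    System.out.println(Arrays.toString(observedPart(row)));
    System.out.println(Arrays.toString(masterPart(row)));
    System.out.println(select(masterPart(row), 2));
  }
}
```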
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iotdb.commons.udf.utils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Stack;
import static java.lang.Math.min;
import static java.lang.Math.sqrt;
public class KDTreeUtil {
private Node kdTree;
private static class Node {
int partitionDimension;
double partitionValue;
ArrayList<Double> value;
boolean isLeaf = false;
Node left;
Node right;
// min value of each dimension
ArrayList<Double> min;
// max value of each dimension
ArrayList<Double> max;
}
public static KDTreeUtil build(ArrayList<ArrayList<Double>> input, int dimension) {
KDTreeUtil tree = new KDTreeUtil();
tree.kdTree = new Node();
tree.buildDetail(tree.kdTree, input, dimension);
return tree;
}
private void buildDetail(Node node, ArrayList<ArrayList<Double>> data, int dimensions) {
if (data.size() == 0) {
return;
}
if (data.size() == 1) {
node.isLeaf = true;
node.value = data.get(0);
return;
}
node.partitionDimension = -1;
double var = -1;
double tmpvar;
for (int i = 0; i < dimensions; i++) {
tmpvar = UtilZ.variance(data, i);
if (tmpvar > var) {
var = tmpvar;
node.partitionDimension = i;
}
}
if (var == 0d) {
node.isLeaf = true;
node.value = data.get(0);
return;
}
node.partitionValue = UtilZ.median(data, node.partitionDimension);
ArrayList<ArrayList<Double>> maxMin = UtilZ.maxMin(data, dimensions);
node.min = maxMin.get(0);
node.max = maxMin.get(1);
ArrayList<ArrayList<Double>> left = new ArrayList<>();
ArrayList<ArrayList<Double>> right = new ArrayList<>();
for (ArrayList<Double> d : data) {
if (d.get(node.partitionDimension) < node.partitionValue) {
left.add(d);
} else if (d.get(node.partitionDimension) > node.partitionValue) {
right.add(d);
}
}
for (ArrayList<Double> d : data) {
if (d.get(node.partitionDimension) == node.partitionValue) {
if (left.size() == 0) {
left.add(d);
} else {
right.add(d);
}
}
}
Node leftNode = new Node();
Node rightNode = new Node();
node.left = leftNode;
node.right = rightNode;
buildDetail(leftNode, left, dimensions);
buildDetail(rightNode, right, dimensions);
}
public ArrayList<Double> query(ArrayList<Double> input, double[] std) {
Node node = kdTree;
Stack<Node> stack = new Stack<>();
while (!node.isLeaf) {
if (input.get(node.partitionDimension) < node.partitionValue) {
stack.add(node.right);
node = node.left;
} else {
stack.push(node.left);
node = node.right;
}
}
double distance = UtilZ.distance(input, node.value, std);
ArrayList<Double> nearest = queryRec(input, distance, stack, std);
return nearest == null ? node.value : nearest;
}
public ArrayList<Double> queryRec(
ArrayList<Double> input, double distance, Stack<Node> stack, double[] std) {
ArrayList<Double> nearest = null;
Node node;
double tdis;
while (stack.size() != 0) {
node = stack.pop();
if (node.isLeaf) {
tdis = UtilZ.distance(input, node.value, std);
if (tdis < distance) {
distance = tdis;
nearest = node.value;
}
} else {
double minDistance = UtilZ.minDistance(input, node.max, node.min, std);
if (minDistance < distance) {
while (!node.isLeaf) {
if (input.get(node.partitionDimension) < node.partitionValue) {
stack.add(node.right);
node = node.left;
} else {
stack.push(node.left);
node = node.right;
}
}
tdis = UtilZ.distance(input, node.value, std);
if (tdis < distance) {
distance = tdis;
nearest = node.value;
}
}
}
}
return nearest;
}
public ArrayList<ArrayList<Double>> queryRecKNN(
ArrayList<Double> input, double distance, Stack<Node> stack, double[] std) {
ArrayList<ArrayList<Double>> nearest = new ArrayList<>();
Node node;
double tdis;
while (stack.size() != 0) {
node = stack.pop();
if (node.isLeaf) {
tdis = UtilZ.distance(input, node.value, std);
if (tdis < distance) {
distance = tdis;
nearest.add(node.value);
}
} else {
double minDistance = UtilZ.minDistance(input, node.max, node.min, std);
if (minDistance < distance) {
while (!node.isLeaf) {
if (input.get(node.partitionDimension) < node.partitionValue) {
stack.add(node.right);
node = node.left;
} else {
stack.push(node.left);
node = node.right;
}
}
tdis = UtilZ.distance(input, node.value, std);
if (tdis < distance) {
distance = tdis;
nearest.add(node.value);
}
}
}
}
return nearest;
}
public ArrayList<Double> findNearest(
ArrayList<Double> input, ArrayList<ArrayList<Double>> nearest, double[] std) {
double min_dis = Double.MAX_VALUE;
int min_index = 0;
for (int i = 0; i < nearest.size(); i++) {
double dis = UtilZ.distance(input, nearest.get(i), std);
if (dis < min_dis) {
min_dis = dis;
min_index = i;
}
}
ArrayList<Double> nt = nearest.get(min_index);
nearest.remove(min_index);
return nt;
}
public ArrayList<ArrayList<Double>> queryKNN(ArrayList<Double> input, int k, double[] std) {
ArrayList<ArrayList<Double>> kNearest = new ArrayList<>();
Node node = kdTree;
Stack<Node> stack = new Stack<>();
while (!node.isLeaf) {
if (input.get(node.partitionDimension) < node.partitionValue) {
stack.add(node.right);
node = node.left;
} else {
stack.push(node.left);
node = node.right;
}
}
double distance = UtilZ.distance(input, node.value, std);
ArrayList<ArrayList<Double>> nearest = queryRecKNN(input, distance, stack, std);
for (int i = 0; i < min(k, nearest.size()); i++) {
kNearest.add(findNearest(input, nearest, std));
}
if (kNearest.size() == 0) {
kNearest.add(node.value);
}
for (ArrayList<Double> doubles : kNearest) {
UtilZ.distance(doubles, input, std);
}
return kNearest;
}
private static class UtilZ {
static double variance(ArrayList<ArrayList<Double>> data, int dimension) {
double sum = 0d;
for (ArrayList<Double> d : data) {
sum += d.get(dimension);
}
double avg = sum / data.size();
double ans = 0d;
for (ArrayList<Double> d : data) {
double temp = d.get(dimension) - avg;
ans += temp * temp;
}
return ans / data.size();
}
static double median(ArrayList<ArrayList<Double>> data, int dimension) {
ArrayList<Double> d = new ArrayList<>();
for (ArrayList<Double> k : data) {
d.add(k.get(dimension));
}
Collections.sort(d);
int pos = d.size() / 2;
return d.get(pos);
}
static ArrayList<ArrayList<Double>> maxMin(ArrayList<ArrayList<Double>> data, int dimensions) {
ArrayList<ArrayList<Double>> mm = new ArrayList<>();
ArrayList<Double> min_v = new ArrayList<>();
ArrayList<Double> max_v = new ArrayList<>();
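for (int i = 0; i < dimensions; i++) {
double min_temp = Double.MAX_VALUE;
// Double.MIN_VALUE is the smallest positive double, so start from -Double.MAX_VALUE
double max_temp = -Double.MAX_VALUE;
// start at index 0 so that the first tuple is part of the bounding box
for (int j = 0; j < data.size(); j++) {
ArrayList<Double> d = data.get(j);
if (d.get(i) < min_temp) {
min_temp = d.get(i);
}
// independent check: a single value can update both the minimum and the maximum
if (d.get(i) > max_temp) {
max_temp = d.get(i);
}
}
min_v.add(min_temp);
max_v.add(max_temp);
}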
mm.add(min_v);
mm.add(max_v);
return mm;
}
static double distance(ArrayList<Double> a, ArrayList<Double> b, double[] std) {
double sum = 0d;
for (int i = 0; i < a.size(); i++) {
if (a.get(i) != null && b.get(i) != null) {
sum += Math.pow((a.get(i) - b.get(i)) / std[i], 2);
}
}
sum = sqrt(sum);
return sum;
}
static double minDistance(
ArrayList<Double> a, ArrayList<Double> max, ArrayList<Double> min, double[] std) {
double sum = 0d;
for (int i = 0; i < a.size(); i++) {
if (a.get(i) > max.get(i)) {
sum += Math.pow((a.get(i) - max.get(i)) / std[i], 2);
} else if (a.get(i) < min.get(i)) {
sum += Math.pow((min.get(i) - a.get(i)) / std[i], 2);
}
}
sum = sqrt(sum);
return sum;
}
}
}
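The pruning test in `queryRec` relies on two quantities: the standardized Euclidean distance between two points, and a lower bound on the distance from the query point to a node's bounding box. A subtree is only descended into when that lower bound is smaller than the best distance found so far. The sketch below restates both with plain arrays; `BoxPruningSketch` and `canSkip` are hypothetical names, not the committed code.

```java
// Hypothetical helper, not part of IoTDB: bounding-box pruning for a k-d tree.
final class BoxPruningSketch {

  // Euclidean distance after dividing each dimension by its standard deviation.
  static double distance(double[] a, double[] b, double[] std) {
    double sum = 0d;
    for (int i = 0; i < a.length; i++) {
      double d = (a[i] - b[i]) / std[i];
      sum += d * d;
    }
    return Math.sqrt(sum);
  }

  // Smallest possible standardized distance from q to any point inside
  // the axis-aligned box [min, max]; zero when q lies inside the box.
  static double minDistanceToBox(double[] q, double[] min, double[] max, double[] std) {
    double sum = 0d;
    for (int i = 0; i < q.length; i++) {
      double gap = 0d;
      if (q[i] > max[i]) {
        gap = q[i] - max[i];
      } else if (q[i] < min[i]) {
        gap = min[i] - q[i];
      }
      gap /= std[i];
      sum += gap * gap;
    }
    return Math.sqrt(sum);
  }

  // A subtree can be skipped when even its closest possible point
  // is no nearer than the best candidate found so far.
  static boolean canSkip(double[] q, double[] min, double[] max, double[] std, double best) {
    return minDistanceToBox(q, min, max, std) >= best;
  }

  public static void main(String[] args) {
    double[] q = {0, 0};
    double[] min = {3, 4};
    double[] max = {5, 6};
    double[] std = {1, 1};
    System.out.println(minDistanceToBox(q, min, max, std));
    System.out.println(canSkip(q, min, max, std, 4.0));
  }
}
```

The bound is only valid when `min` and `max` cover every point stored under the node, which is why the bounding box must include all tuples.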
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iotdb.commons.udf.utils;
import org.apache.iotdb.udf.api.access.Row;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
public class MasterRepairUtil {
private final ArrayList<ArrayList<Double>> td = new ArrayList<>();
private final ArrayList<ArrayList<Double>> tdCleaned = new ArrayList<>();
private final ArrayList<ArrayList<Double>> md = new ArrayList<>();
private final ArrayList<Long> tdTime = new ArrayList<>();
private final int columnCnt;
private long omega;
private Double eta;
private int k;
private double[] std;
private KDTreeUtil kdTreeUtil;
public MasterRepairUtil(int columnCnt, long omega, double eta, int k) {
this.columnCnt = columnCnt;
this.omega = omega;
this.eta = eta;
this.k = k;
}
public boolean isNullRow(Row row) {
boolean flag = true;
for (int i = 0; i < row.size(); i++) {
if (!row.isNull(i)) {
flag = false;
break;
}
}
return flag;
}
public void addRow(Row row) throws Exception {
ArrayList<Double> tt = new ArrayList<>(); // time-series tuple
boolean containsNotNull = false;
for (int i = 0; i < this.columnCnt; i++) {
if (!row.isNull(i)) {
containsNotNull = true;
BigDecimal bd = BigDecimal.valueOf(getValueAsDouble(row, i));
tt.add(bd.doubleValue());
} else {
tt.add(null);
}
}
if (containsNotNull) {
td.add(tt);
tdTime.add(row.getTime());
}
ArrayList<Double> mt = new ArrayList<>(); // master tuple
containsNotNull = false;
for (int i = this.columnCnt; i < row.size(); i++) {
if (!row.isNull(i)) {
containsNotNull = true;
BigDecimal bd = BigDecimal.valueOf(getValueAsDouble(row, i));
mt.add(bd.doubleValue());
} else {
mt.add(null);
}
}
if (containsNotNull) {
md.add(mt);
}
}
public static double getValueAsDouble(Row row, int index) throws Exception {
double ans;
try {
switch (row.getDataType(index)) {
case INT32:
ans = row.getInt(index);
break;
case INT64:
ans = row.getLong(index);
break;
case FLOAT:
ans = row.getFloat(index);
break;
case DOUBLE:
ans = row.getDouble(index);
break;
default:
throw new Exception("The value of the input time series is not numeric.\n");
}
} catch (IOException e) {
throw new Exception("Fail to get data type in row " + row.getTime(), e);
}
return ans;
}
public void buildKDTree() {
this.kdTreeUtil = KDTreeUtil.build(md, this.columnCnt);
}
public ArrayList<Double> getCleanResultColumn(int columnPos) {
ArrayList<Double> column = new ArrayList<>();
for (ArrayList<Double> tuple : this.tdCleaned) {
column.add(tuple.get(columnPos - 1));
}
return column;
}
public ArrayList<Long> getTime() {
return tdTime;
}
public double getTmDistance(ArrayList<Double> tTuple, ArrayList<Double> mTuple) {
double distance = 0d;
for (int pos = 0; pos < columnCnt; pos++) {
double temp = tTuple.get(pos) - mTuple.get(pos);
temp = temp / std[pos];
distance += temp * temp;
}
distance = Math.sqrt(distance);
return distance;
}
public ArrayList<Integer> calW(int i) {
ArrayList<Integer> Wi = new ArrayList<>();
for (int l = i - 1; l >= 0; l--) {
if (this.tdTime.get(i) <= this.tdTime.get(l) + omega) {
Wi.add(l);
}
}
return Wi;
}
public ArrayList<ArrayList<Double>> calC(int i, ArrayList<Integer> Wi) {
ArrayList<ArrayList<Double>> Ci = new ArrayList<>();
if (Wi.size() == 0) {
Ci.add(this.kdTreeUtil.query(this.td.get(i), std));
} else {
Ci.addAll(this.kdTreeUtil.queryKNN(this.td.get(i), k, std));
for (Integer integer : Wi) {
Ci.addAll(this.kdTreeUtil.queryKNN(this.tdCleaned.get(integer), k, std));
}
}
return Ci;
}
public void masterRepair() {
for (int i = 0; i < this.td.size(); i++) {
ArrayList<Double> tuple = this.td.get(i);
ArrayList<Integer> Wi = calW(i);
ArrayList<ArrayList<Double>> Ci = this.calC(i, Wi);
double minDis = Double.MAX_VALUE;
ArrayList<Double> repairTuple = new ArrayList<>();
for (ArrayList<Double> ci : Ci) {
boolean smooth = true;
for (Integer wi : Wi) {
ArrayList<Double> wis = tdCleaned.get(wi);
if (getTmDistance(ci, wis) > eta) {
smooth = false;
break;
}
}
if (smooth) {
double dis = getTmDistance(ci, tuple);
if (dis < minDis) {
minDis = dis;
repairTuple = ci;
}
}
}
this.tdCleaned.add(repairTuple);
}
}
public void setParameters() {
if (omega == -1) {
ArrayList<Long> intervals = getIntervals();
Collections.sort(intervals);
long interval = intervals.get(intervals.size() / 2);
omega = interval * 10;
}
if (Double.isNaN(eta)) {
ArrayList<Double> distanceList = new ArrayList<>();
for (int i = 1; i < this.td.size(); i++) {
for (int l = i - 1; l >= 0; l--) {
if (this.tdTime.get(i) <= this.tdTime.get(l) + omega) {
distanceList.add(getTmDistance(this.td.get(i), this.td.get(l)));
} else break;
}
}
Collections.sort(distanceList);
eta = distanceList.get((int) (distanceList.size() * 0.9973));
}
if (k == -1) {
for (int tempK = 2; tempK <= 5; tempK++) {
ArrayList<Double> distanceList = new ArrayList<>();
for (ArrayList<Double> tuple : this.td) {
ArrayList<ArrayList<Double>> neighbors = this.kdTreeUtil.queryKNN(tuple, tempK, std);
for (ArrayList<Double> neighbor : neighbors) {
distanceList.add(getTmDistance(tuple, neighbor));
}
}
Collections.sort(distanceList);
if (distanceList.get((int) (distanceList.size() * 0.9)) > eta) {
k = tempK;
break;
}
}
if (k == -1) {
k = Integer.min(5, this.md.size());
}
}
}
private double varianceImperative(double[] value) {
double average = 0.0;
int cnt = 0;
for (double p : value) {
if (!Double.isNaN(p)) {
cnt += 1;
average += p;
}
}
if (cnt == 0) {
return 0d;
}
average /= cnt;
double variance = 0.0;
for (double p : value) {
if (!Double.isNaN(p)) {
variance += (p - average) * (p - average);
}
}
return variance / cnt;
}
private double[] getColumn(int pos) {
double[] column = new double[this.td.size()];
for (int i = 0; i < this.td.size(); i++) {
column[i] = this.td.get(i).get(pos);
}
return column;
}
public void callStd() {
this.std = new double[this.columnCnt];
for (int i = 0; i < this.columnCnt; i++) {
std[i] = Math.sqrt(varianceImperative(getColumn(i)));
}
}
public void repair() {
fillNullValue();
buildKDTree();
callStd();
setParameters();
masterRepair();
}
public ArrayList<Long> getIntervals() {
ArrayList<Long> intervals = new ArrayList<>();
for (int i = 1; i < this.tdTime.size(); i++) {
intervals.add(this.tdTime.get(i) - this.tdTime.get(i - 1));
}
return intervals;
}
public void fillNullValue() {
for (int i = 0; i < columnCnt; i++) {
double temp = this.td.get(0).get(i);
for (ArrayList<Double> arrayList : this.td) {
if (arrayList.get(i) == null) {
arrayList.set(i, temp);
} else {
temp = arrayList.get(i);
}
}
}
}
}
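The core selection step in `masterRepair()` can be summarized in a short sketch: among the candidate master tuples, keep only those within `eta` of every already repaired tuple in the window, then pick the one closest to the observed tuple. The class `SmoothRepairSketch` and its method names are hypothetical, plain arrays replace the lists used in the commit, and unlike the committed code (which falls back to an empty tuple) this sketch returns `null` when no candidate qualifies.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of IoTDB: candidate selection under a smoothness constraint.
final class SmoothRepairSketch {

  // Euclidean distance after dividing each dimension by its standard deviation.
  static double distance(double[] a, double[] b, double[] std) {
    double sum = 0d;
    for (int i = 0; i < a.length; i++) {
      double d = (a[i] - b[i]) / std[i];
      sum += d * d;
    }
    return Math.sqrt(sum);
  }

  static double[] pickRepair(
      double[] observed,
      List<double[]> candidates,
      List<double[]> repairedWindow,
      double[] std,
      double eta) {
    double[] best = null;
    double bestDistance = Double.MAX_VALUE;
    for (double[] candidate : candidates) {
      // A candidate is smooth if it is within eta of every repaired tuple in the window.
      boolean smooth = true;
      for (double[] previous : repairedWindow) {
        if (distance(candidate, previous, std) > eta) {
          smooth = false;
          break;
        }
      }
      if (smooth) {
        double d = distance(candidate, observed, std);
        if (d < bestDistance) {
          bestDistance = d;
          best = candidate;
        }
      }
    }
    return best; // null when no candidate satisfies the constraint
  }

  public static void main(String[] args) {
    double[] observed = {1694.0};
    List<double[]> candidates = Arrays.asList(new double[] {1690.0}, new double[] {1704.0});
    List<double[]> window = Arrays.asList(new double[] {1703.0});
    double[] repaired = pickRepair(observed, candidates, window, new double[] {1.0}, 3.0);
    System.out.println(Arrays.toString(repaired));
  }
}
```

In this illustrative input, 1690 is closer to the observed value but lies 13 away from the repaired neighbor, so only 1704 satisfies the threshold of 3 and is chosen.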