Commit 380fa6e4 authored by Aljoscha Krettek

Add scala code to mainline

Parent 09012c0e
@@ -8,6 +8,7 @@ target
.project
.settings
.cache
.metadata/*
*.class
filter.properties
tmp/*
......
@@ -91,6 +91,13 @@ Get some test data:
To contribute back to the project or to develop your own jobs for Stratosphere, you need a working development environment. We use Eclipse and IntelliJ for development. Here we focus on Eclipse.
If you want to work on the Scala code, you will need the following plugins:

* scala-ide: http://download.scala-ide.org/sdk/e38/scala210/stable/site
* m2eclipse-scala: http://alchim31.free.fr/m2e-scala/update-site
* build-helper-maven-plugin: https://repository.sonatype.org/content/repositories/forge-sites/m2e-extras/0.15.0/N/0.15.0.201206251206/

If you do not install these plugins, the Scala projects will show build errors; you can simply close those projects and ignore them.
Import the Stratosphere source code using Maven's Import tool:

* Select "Import" from the "File" menu.
* Expand the "Maven" node, select "Existing Maven Projects", and click the "Next" button.
......
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <parent>
    <artifactId>pact-scala</artifactId>
    <groupId>eu.stratosphere</groupId>
    <version>0.4-ozone-SNAPSHOT</version>
    <relativePath>..</relativePath>
  </parent>

  <artifactId>pact-scala-core</artifactId>
  <name>pact-scala-core</name>
  <packaging>jar</packaging>

  <dependencies>
    <dependency>
      <groupId>eu.stratosphere</groupId>
      <artifactId>pact-common</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>eu.stratosphere</groupId>
      <artifactId>pact-compiler</artifactId>
      <version>${project.version}</version>
    </dependency>
  </dependencies>

  <reporting>
    <plugins>
    </plugins>
  </reporting>

  <build>
    <plugins>
    </plugins>
  </build>
</project>
/***********************************************************************************************************************
*
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
**********************************************************************************************************************/
package eu.stratosphere.scala.contracts;
import java.lang.annotation.Annotation;
import java.util.Arrays;
import eu.stratosphere.pact.common.contract.ReduceContract;
import eu.stratosphere.pact.common.stubs.StubAnnotation;
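/**
 * Factory methods that create runtime instances of the PACT stub annotations.
 * The Scala frontend builds contracts programmatically instead of annotating stub
 * classes, so the instances created here are handed out through
 * ScalaContract.getUserCodeAnnotation (defined later in this commit).
 */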
public class Annotations {
public static Annotation getConstantFields(int[] fields) {
return new ConstantFields(fields);
}
public static Annotation getConstantFieldsFirst(int[] fields) {
return new ConstantFieldsFirst(fields);
}
public static Annotation getConstantFieldsSecond(int[] fields) {
return new ConstantFieldsSecond(fields);
}
public static Annotation getConstantFieldsExcept(int[] fields) {
return new ConstantFieldsExcept(fields);
}
public static Annotation getConstantFieldsFirstExcept(int[] fields) {
return new ConstantFieldsFirstExcept(fields);
}
public static Annotation getConstantFieldsSecondExcept(int[] fields) {
return new ConstantFieldsSecondExcept(fields);
}
public static Annotation getCombinable() {
return new Combinable();
}
private static abstract class Fields<T extends Annotation> implements Annotation {
private final Class<T> clazz;
private final int[] fields;
public Fields(Class<T> clazz, int[] fields) {
this.clazz = clazz;
this.fields = fields;
}
public int[] value() {
return fields;
}
@Override
public Class<? extends Annotation> annotationType() {
return clazz;
}
@Override
@SuppressWarnings("unchecked")
public boolean equals(Object obj) {
if (obj == null || !annotationType().isAssignableFrom(obj.getClass()))
return false;
if (!annotationType().equals(((Annotation) obj).annotationType()))
return false;
int[] otherFields = getOtherFields((T) obj);
return Arrays.equals(fields, otherFields);
}
protected abstract int[] getOtherFields(T other);
@Override
public int hashCode() {
return 0xf16cd51b ^ Arrays.hashCode(fields);
}
}
@SuppressWarnings("all")
private static class ConstantFields extends Fields<StubAnnotation.ConstantFields> implements
StubAnnotation.ConstantFields {
public ConstantFields(int[] fields) {
super(StubAnnotation.ConstantFields.class, fields);
}
@Override
protected int[] getOtherFields(StubAnnotation.ConstantFields other) {
return other.value();
}
}
@SuppressWarnings("all")
private static class ConstantFieldsFirst extends Fields<StubAnnotation.ConstantFieldsFirst> implements
StubAnnotation.ConstantFieldsFirst {
public ConstantFieldsFirst(int[] fields) {
super(StubAnnotation.ConstantFieldsFirst.class, fields);
}
@Override
protected int[] getOtherFields(StubAnnotation.ConstantFieldsFirst other) {
return other.value();
}
}
@SuppressWarnings("all")
private static class ConstantFieldsSecond extends Fields<StubAnnotation.ConstantFieldsSecond> implements
StubAnnotation.ConstantFieldsSecond {
public ConstantFieldsSecond(int[] fields) {
super(StubAnnotation.ConstantFieldsSecond.class, fields);
}
@Override
protected int[] getOtherFields(StubAnnotation.ConstantFieldsSecond other) {
return other.value();
}
}
@SuppressWarnings("all")
private static class ConstantFieldsExcept extends Fields<StubAnnotation.ConstantFieldsExcept> implements
StubAnnotation.ConstantFieldsExcept {
public ConstantFieldsExcept(int[] fields) {
super(StubAnnotation.ConstantFieldsExcept.class, fields);
}
@Override
protected int[] getOtherFields(StubAnnotation.ConstantFieldsExcept other) {
return other.value();
}
}
@SuppressWarnings("all")
private static class ConstantFieldsFirstExcept extends Fields<StubAnnotation.ConstantFieldsFirstExcept> implements
StubAnnotation.ConstantFieldsFirstExcept {
public ConstantFieldsFirstExcept(int[] fields) {
super(StubAnnotation.ConstantFieldsFirstExcept.class, fields);
}
@Override
protected int[] getOtherFields(StubAnnotation.ConstantFieldsFirstExcept other) {
return other.value();
}
}
@SuppressWarnings("all")
private static class ConstantFieldsSecondExcept extends Fields<StubAnnotation.ConstantFieldsSecondExcept> implements
StubAnnotation.ConstantFieldsSecondExcept {
public ConstantFieldsSecondExcept(int[] fields) {
super(StubAnnotation.ConstantFieldsSecondExcept.class, fields);
}
@Override
protected int[] getOtherFields(StubAnnotation.ConstantFieldsSecondExcept other) {
return other.value();
}
}
@SuppressWarnings("all")
private static class Combinable implements Annotation, ReduceContract.Combinable {
public Combinable() {
}
@Override
public Class<? extends Annotation> annotationType() {
return ReduceContract.Combinable.class;
}
@Override
public boolean equals(Object obj) {
if (obj == null || !annotationType().isAssignableFrom(obj.getClass()))
return false;
if (!annotationType().equals(((Annotation) obj).annotationType()))
return false;
return true;
}
@Override
public int hashCode() {
return 0;
}
}
}
\ No newline at end of file
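// A hypothetical sketch (not part of this commit) of how the factories above are
// used from Scala: the instances feed the annotations hook that
// ScalaContract.getUserCodeAnnotation consults below.
object AnnotationsUsageSketch {
  import java.lang.annotation.Annotation
  import eu.stratosphere.scala.contracts.Annotations

  val reduceAnnotations: Seq[Annotation] = Seq(
    Annotations.getConstantFields(Array(0)), // field 0 passes through unchanged
    Annotations.getCombinable())             // mark the reduce stub as combinable
}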
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala
import language.experimental.macros
import scala.util.DynamicVariable
import eu.stratosphere.scala.analysis._
import eu.stratosphere.scala.contracts._
import eu.stratosphere.pact.common.util.{ FieldSet => PactFieldSet }
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.scala.codegen.MacroContextHolder
import scala.reflect.macros.Context
case class KeyCardinality(key: FieldSelector, isUnique: Boolean, distinctCount: Option[Long], avgNumRecords: Option[Float]) {
private class RefreshableFieldSet extends PactFieldSet {
def refresh(indexes: Set[Int]) = {
this.collection.clear()
for (index <- indexes)
this.add(index)
}
}
@transient private var pactFieldSets = collection.mutable.Map[Contract with ScalaContract[_], RefreshableFieldSet]()
def getPactFieldSet(contract: Contract with ScalaContract[_]): PactFieldSet = {
if (pactFieldSets == null)
pactFieldSets = collection.mutable.Map[Contract with ScalaContract[_], RefreshableFieldSet]()
val keyCopy = key.copy
contract.getUDF.attachOutputsToInputs(keyCopy.inputFields)
val keySet = keyCopy.selectedFields.toIndexSet
val fieldSet = pactFieldSets.getOrElseUpdate(contract, new RefreshableFieldSet())
fieldSet.refresh(keySet)
fieldSet
}
}
trait OutputHintable[Out] { this: DataStream[Out] =>
def getContract = contract
private var _cardinalities: List[KeyCardinality] = List[KeyCardinality]()
def addCardinality(card: KeyCardinality) { _cardinalities = card :: _cardinalities }
def degreeOfParallelism = contract.getDegreeOfParallelism()
def degreeOfParallelism_=(value: Int) = contract.setDegreeOfParallelism(value)
def degreeOfParallelism(value: Int): this.type = { contract.setDegreeOfParallelism(value); this }
def avgBytesPerRecord = contract.getCompilerHints().getAvgBytesPerRecord()
def avgBytesPerRecord_=(value: Float) = contract.getCompilerHints().setAvgBytesPerRecord(value)
def avgBytesPerRecord(value: Float): this.type = { contract.getCompilerHints().setAvgBytesPerRecord(value); this }
def avgRecordsEmittedPerCall = contract.getCompilerHints().getAvgRecordsEmittedPerStubCall()
def avgRecordsEmittedPerCall_=(value: Float) = contract.getCompilerHints().setAvgRecordsEmittedPerStubCall(value)
def avgRecordsEmittedPerCall(value: Float): this.type = { contract.getCompilerHints().setAvgRecordsEmittedPerStubCall(value); this }
def uniqueKey[Key](fields: Out => Key) = macro OutputHintableMacros.uniqueKey[Out, Key]
def uniqueKey[Key](fields: Out => Key, distinctCount: Long) = macro OutputHintableMacros.uniqueKeyWithDistinctCount[Out, Key]
def cardinality[Key](fields: Out => Key) = macro OutputHintableMacros.cardinality[Out, Key]
def cardinality[Key](fields: Out => Key, distinctCount: Long) = macro OutputHintableMacros.cardinalityWithDistinctCount[Out, Key]
def cardinality[Key](fields: Out => Key, avgNumRecords: Float) = macro OutputHintableMacros.cardinalityWithAvgNumRecords[Out, Key]
def cardinality[Key](fields: Out => Key, distinctCount: Long, avgNumRecords: Float) = macro OutputHintableMacros.cardinalityWithAll[Out, Key]
def applyHints(contract: Contract with ScalaContract[_]): Unit = {
val hints = contract.getCompilerHints
if (hints.getUniqueFields != null)
hints.getUniqueFields.clear()
hints.getDistinctCounts.clear()
hints.getAvgNumRecordsPerDistinctFields.clear()
_cardinalities.foreach { card =>
val fieldSet = card.getPactFieldSet(contract)
if (card.isUnique) {
hints.addUniqueField(fieldSet)
}
card.distinctCount.foreach(hints.setDistinctCount(fieldSet, _))
card.avgNumRecords.foreach(hints.setAvgNumRecordsPerDistinctFields(fieldSet, _))
}
}
}
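// A hypothetical sketch of how the hints above compose. The fluent setters return
// this.type, and the key-based hints expand via OutputHintableMacros below.
object OutputHintsSketch {
  def applyHints(counts: DataStream[(String, Int)] with OutputHintable[(String, Int)]) {
    counts.degreeOfParallelism(4).avgBytesPerRecord(16.0f)
    counts.uniqueKey({ t => t._1 }, 1000L) // unique key with a hypothetical distinct count
  }
}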
object OutputHintableMacros {
def uniqueKey[Out: c.WeakTypeTag, Key: c.WeakTypeTag](c: Context { type PrefixType = OutputHintable[Out] })(fields: c.Expr[Out => Key]): c.Expr[Unit] = {
import c.universe._
val slave = MacroContextHolder.newMacroHelper(c)
val generatedKeySelector = slave.getSelector(fields)
val result = reify {
val contract = c.prefix.splice.getContract
val hints = contract.getCompilerHints
val keySelection = generatedKeySelector.splice
val key = new FieldSelector(c.prefix.splice.getContract.getUDF.outputUDT, keySelection)
val card = KeyCardinality(key, true, None, None)
c.prefix.splice.addCardinality(card)
}
result
}
def uniqueKeyWithDistinctCount[Out: c.WeakTypeTag, Key: c.WeakTypeTag](c: Context { type PrefixType = OutputHintable[Out] })(fields: c.Expr[Out => Key], distinctCount: c.Expr[Long]): c.Expr[Unit] = {
import c.universe._
val slave = MacroContextHolder.newMacroHelper(c)
val generatedKeySelector = slave.getSelector(fields)
val result = reify {
val contract = c.prefix.splice.getContract
val hints = contract.getCompilerHints
val keySelection = generatedKeySelector.splice
val key = new FieldSelector(c.prefix.splice.getContract.getUDF.outputUDT, keySelection)
val card = KeyCardinality(key, true, Some(distinctCount.splice), None)
c.prefix.splice.addCardinality(card)
}
result
}
def cardinality[Out: c.WeakTypeTag, Key: c.WeakTypeTag](c: Context { type PrefixType = OutputHintable[Out] })(fields: c.Expr[Out => Key]): c.Expr[Unit] = {
import c.universe._
val slave = MacroContextHolder.newMacroHelper(c)
val generatedKeySelector = slave.getSelector(fields)
val result = reify {
val contract = c.prefix.splice.getContract
val hints = contract.getCompilerHints
val keySelection = generatedKeySelector.splice
val key = new FieldSelector(c.prefix.splice.getContract.getUDF.outputUDT, keySelection)
val card = KeyCardinality(key, false, None, None)
c.prefix.splice.addCardinality(card)
}
result
}
def cardinalityWithDistinctCount[Out: c.WeakTypeTag, Key: c.WeakTypeTag](c: Context { type PrefixType = OutputHintable[Out] })(fields: c.Expr[Out => Key], distinctCount: c.Expr[Long]): c.Expr[Unit] = {
import c.universe._
val slave = MacroContextHolder.newMacroHelper(c)
val generatedKeySelector = slave.getSelector(fields)
val result = reify {
val contract = c.prefix.splice.getContract
val hints = contract.getCompilerHints
val keySelection = generatedKeySelector.splice
val key = new FieldSelector(c.prefix.splice.getContract.getUDF.outputUDT, keySelection)
val card = KeyCardinality(key, false, Some(distinctCount.splice), None)
c.prefix.splice.addCardinality(card)
}
result
}
def cardinalityWithAvgNumRecords[Out: c.WeakTypeTag, Key: c.WeakTypeTag](c: Context { type PrefixType = OutputHintable[Out] })(fields: c.Expr[Out => Key], avgNumRecords: c.Expr[Float]): c.Expr[Unit] = {
import c.universe._
val slave = MacroContextHolder.newMacroHelper(c)
val generatedKeySelector = slave.getSelector(fields)
val result = reify {
val contract = c.prefix.splice.getContract
val hints = contract.getCompilerHints
val keySelection = generatedKeySelector.splice
val key = new FieldSelector(c.prefix.splice.getContract.getUDF.outputUDT, keySelection)
val card = KeyCardinality(key, false, None, Some(avgNumRecords.splice))
c.prefix.splice.addCardinality(card)
}
result
}
def cardinalityWithAll[Out: c.WeakTypeTag, Key: c.WeakTypeTag](c: Context { type PrefixType = OutputHintable[Out] })(fields: c.Expr[Out => Key], distinctCount: c.Expr[Long], avgNumRecords: c.Expr[Float]): c.Expr[Unit] = {
import c.universe._
val slave = MacroContextHolder.newMacroHelper(c)
val generatedKeySelector = slave.getSelector(fields)
val result = reify {
val contract = c.prefix.splice.getContract
val hints = contract.getCompilerHints
val keySelection = generatedKeySelector.splice
val key = new FieldSelector(c.prefix.splice.getContract.getUDF.outputUDT, keySelection)
val card = KeyCardinality(key, false, Some(distinctCount.splice), Some(avgNumRecords.splice))
c.prefix.splice.addCardinality(card)
}
result
}
}
trait InputHintable[In, Out] { this: DataStream[Out] =>
def markUnread: Int => Unit
def markCopied: (Int, Int) => Unit
def getInputUDT: UDT[In]
def getOutputUDT: UDT[Out]
def neglects[Fields](fields: In => Fields): Unit = macro InputHintableMacros.neglects[In, Out, Fields]
def observes[Fields](fields: In => Fields): Unit = macro InputHintableMacros.observes[In, Out, Fields]
def preserves[Fields](from: In => Fields, to: Out => Fields) = macro InputHintableMacros.preserves[In, Out, Fields]
}
object InputHintable {
private val enabled = new DynamicVariable[Boolean](true)
def withEnabled[T](isEnabled: Boolean)(thunk: => T): T = enabled.withValue(isEnabled) { thunk }
}
object InputHintableMacros {
def neglects[In: c.WeakTypeTag, Out: c.WeakTypeTag, Fields: c.WeakTypeTag](c: Context { type PrefixType = InputHintable[In, Out] })(fields: c.Expr[In => Fields]): c.Expr[Unit] = {
import c.universe._
val slave = MacroContextHolder.newMacroHelper(c)
val generatedFieldSelector = slave.getSelector(fields)
val result = reify {
val fieldSelection = generatedFieldSelector.splice
val fieldSelector = new FieldSelector(c.prefix.splice.getInputUDT, fieldSelection)
val unreadFields = fieldSelector.selectedFields.map(_.localPos).toSet
unreadFields.foreach(c.prefix.splice.markUnread(_))
}
result
}
def observes[In: c.WeakTypeTag, Out: c.WeakTypeTag, Fields: c.WeakTypeTag](c: Context { type PrefixType = InputHintable[In, Out] })(fields: c.Expr[In => Fields]): c.Expr[Unit] = {
import c.universe._
val slave = MacroContextHolder.newMacroHelper(c)
val generatedFieldSelector = slave.getSelector(fields)
val result = reify {
val fieldSelection = generatedFieldSelector.splice
val fieldSelector = new FieldSelector(c.prefix.splice.getInputUDT, fieldSelection)
val fieldSet = fieldSelector.selectedFields.map(_.localPos).toSet
val unreadFields = fieldSelector.inputFields.map(_.localPos).toSet.diff(fieldSet)
unreadFields.foreach(c.prefix.splice.markUnread(_))
}
result
}
def preserves[In: c.WeakTypeTag, Out: c.WeakTypeTag, Fields: c.WeakTypeTag](c: Context { type PrefixType = InputHintable[In, Out] })(from: c.Expr[In => Fields], to: c.Expr[Out => Fields]): c.Expr[Unit] = {
import c.universe._
val slave = MacroContextHolder.newMacroHelper(c)
val generatedFromFieldSelector = slave.getSelector(from)
val generatedToFieldSelector = slave.getSelector(to)
val result = reify {
val fromSelection = generatedFromFieldSelector.splice
val fromSelector = new FieldSelector(c.prefix.splice.getInputUDT, fromSelection)
val toSelection = generatedToFieldSelector.splice
val toSelector = new FieldSelector(c.prefix.splice.getOutputUDT, toSelection)
val pairs = fromSelector.selectedFields.map(_.localPos).zip(toSelector.selectedFields.map(_.localPos))
pairs.foreach(c.prefix.splice.markCopied.tupled)
}
result
}
}
trait OneInputHintable[In, Out] extends InputHintable[In, Out] with OutputHintable[Out] { this: DataStream[Out] =>
override def markUnread = contract.getUDF.asInstanceOf[UDF1[In, Out]].markInputFieldUnread _
override def markCopied = contract.getUDF.asInstanceOf[UDF1[In, Out]].markFieldCopied _
override def getInputUDT = contract.getUDF.asInstanceOf[UDF1[In, Out]].inputUDT
override def getOutputUDT = contract.getUDF.asInstanceOf[UDF1[In, Out]].outputUDT
}
trait TwoInputHintable[LeftIn, RightIn, Out] extends OutputHintable[Out] { this: DataStream[Out] =>
val left = new DataStream[Out](contract) with OneInputHintable[LeftIn, Out] {
override def markUnread = { pos: Int => contract.getUDF.asInstanceOf[UDF2[LeftIn, RightIn, Out]].markInputFieldUnread(Left(pos))}
override def markCopied = { (from: Int, to: Int) => contract.getUDF.asInstanceOf[UDF2[LeftIn, RightIn, Out]].markFieldCopied(Left(from), to)}
override def getInputUDT = contract.getUDF.asInstanceOf[UDF2[LeftIn, RightIn, Out]].leftInputUDT
override def getOutputUDT = contract.getUDF.asInstanceOf[UDF2[LeftIn, RightIn, Out]].outputUDT
}
val right = new DataStream[Out](contract) with OneInputHintable[RightIn, Out] {
override def markUnread = { pos: Int => contract.getUDF.asInstanceOf[UDF2[LeftIn, RightIn, Out]].markInputFieldUnread(Right(pos))}
override def markCopied = { (from: Int, to: Int) => contract.getUDF.asInstanceOf[UDF2[LeftIn, RightIn, Out]].markFieldCopied(Right(from), to)}
override def getInputUDT = contract.getUDF.asInstanceOf[UDF2[LeftIn, RightIn, Out]].rightInputUDT
override def getOutputUDT = contract.getUDF.asInstanceOf[UDF2[LeftIn, RightIn, Out]].outputUDT
}
}
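// A hypothetical sketch of the input-side hints: they apply directly to one-input
// operators and through the left/right views above for two-input operators.
object InputHintsSketch {
  def applyHints(op: DataStream[(String, Int)] with OneInputHintable[(String, Int), (String, Int)]) {
    op.neglects { t => t._2 }                      // the Int field is never read
    op.preserves({ t => t._1 }, { out => out._1 }) // the String field is copied through
  }
}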
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala
import java.net.URI
import eu.stratosphere.scala.analysis._
import eu.stratosphere.pact.common.contract.FileDataSink
import eu.stratosphere.pact.generic.io.OutputFormat
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.nephele.configuration.Configuration
import eu.stratosphere.pact.generic.io.FileOutputFormat
import eu.stratosphere.pact.common.contract.GenericDataSink
object DataSinkOperator {
def write[In](input: DataStream[In], url: String, format: DataSinkFormat[In]): ScalaSink[In] = {
val uri = getUri(url)
val ret = uri.getScheme match {
case "file" | "hdfs" => new FileDataSink(format.format.asInstanceOf[FileOutputFormat[_]], uri.toString, input.contract)
with OneInputScalaContract[In, Nothing] {
def getUDF = format.getUDF
override def persistConfiguration() = format.persistConfiguration(this.getParameters())
}
}
new ScalaSink(ret)
}
private def getUri(url: String) = {
val uri = new URI(url)
if (uri.getScheme == null)
new URI("file://" + url)
else
uri
}
}
class ScalaSink[In] private[scala] (private[scala] val sink: GenericDataSink)
abstract class DataSinkFormat[In](val format: OutputFormat[_], val udt: UDT[In]) {
def getUDF: UDF1[In, Nothing]
def persistConfiguration(config: Configuration) = {}
}
\ No newline at end of file
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala
import java.net.URI
import eu.stratosphere.scala.analysis._
import eu.stratosphere.scala.operators.stubs._
import eu.stratosphere.pact.common.`type`.base._
import eu.stratosphere.pact.common.`type`.base.parser._
import eu.stratosphere.pact.generic.io.InputFormat
import eu.stratosphere.pact.common.contract.GenericDataSource
import eu.stratosphere.pact.common.contract.FileDataSource
import eu.stratosphere.nephele.configuration.Configuration
import eu.stratosphere.pact.generic.io.FileInputFormat
import eu.stratosphere.pact.generic.io.GenericInputFormat
import eu.stratosphere.scala.operators.TextDataSourceFormat
object DataSource {
def apply[Out](url: String, format: DataSourceFormat[Out]): DataStream[Out] with OutputHintable[Out] = {
val uri = getUri(url)
val ret = uri.getScheme match {
case "file" | "hdfs" => new FileDataSource(format.format.asInstanceOf[FileInputFormat[_]], uri.toString)
with ScalaContract[Out] {
override def getUDF = format.getUDF
override def persistConfiguration() = format.persistConfiguration(this.getParameters())
}
case "ext" => new GenericDataSource[GenericInputFormat[_]](format.format.asInstanceOf[GenericInputFormat[_]], uri.toString)
with ScalaContract[Out] {
override def getUDF = format.getUDF
override def persistConfiguration() = format.persistConfiguration(this.getParameters())
}
}
new DataStream[Out](ret) with OutputHintable[Out] {}
}
private def getUri(url: String) = {
val uri = new URI(url)
if (uri.getScheme == null)
new URI("file://" + url)
else
uri
}
}
abstract class DataSourceFormat[Out](val format: InputFormat[_, _], val udt: UDT[Out]) {
def getUDF: UDF0[Out]
def persistConfiguration(config: Configuration) = {}
}
// convenience text file to look good in word count example :D
object TextFile {
def apply(url: String): DataStream[String] with OutputHintable[String] = DataSource(url, TextDataSourceFormat())
}
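// Hypothetical usage of the sources above; getUri adds the file:// scheme when the
// URL carries none. The HDFS path is made up for illustration.
object DataSourceSketch {
  val remote = TextFile("hdfs://namenode/data/input.txt")
  val local = TextFile("/tmp/input.txt") // becomes file:///tmp/input.txt
}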
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala
import language.experimental.macros
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.scala.operators.CoGroupDataStream
import eu.stratosphere.scala.operators.CrossDataStream
import eu.stratosphere.scala.operators.JoinDataStream
import eu.stratosphere.scala.operators.MapMacros
import eu.stratosphere.scala.operators.GroupByDataStream
import eu.stratosphere.scala.operators.ReduceMacros
import eu.stratosphere.scala.operators.UnionMacros
import eu.stratosphere.scala.operators.IterateMacros
import eu.stratosphere.scala.operators.WorksetIterateMacros
class DataStream[T] (val contract: Contract with ScalaContract[T]) {
def cogroup[RightIn](rightInput: DataStream[RightIn]) = new CoGroupDataStream[T, RightIn](this, rightInput)
def cross[RightIn](rightInput: DataStream[RightIn]) = new CrossDataStream[T, RightIn](this, rightInput)
def join[RightIn](rightInput: DataStream[RightIn]) = new JoinDataStream[T, RightIn](this, rightInput)
def map[Out](fun: T => Out): DataStream[Out] = macro MapMacros.map[T, Out]
def flatMap[Out](fun: T => Iterator[Out]) = macro MapMacros.flatMap[T, Out]
def filter(fun: T => Boolean): DataStream[T] = macro MapMacros.filter[T]
// reduce
def groupBy[Key](keyFun: T => Key) = macro ReduceMacros.groupByImpl[T, Key]
def union(secondInput: DataStream[T]): DataStream[T] = macro UnionMacros.impl[T]
def iterateWithDelta[DeltaItem](stepFunction: DataStream[T] => (DataStream[T], DataStream[DeltaItem])) = macro IterateMacros.iterateWithDelta[T, DeltaItem]
def iterate(n: Int, stepFunction: DataStream[T] => DataStream[T]): DataStream[T] = macro IterateMacros.iterate[T]
def iterateWithWorkset[SolutionKey, WorksetItem](workset: DataStream[WorksetItem], solutionSetKey: T => SolutionKey, stepFunction: (DataStream[T], DataStream[WorksetItem]) => (DataStream[T], DataStream[WorksetItem])) = macro WorksetIterateMacros.iterateWithWorkset[T, SolutionKey, WorksetItem]
def write(url: String, format: DataSinkFormat[T]) = DataSinkOperator.write(this, url, format)
}
\ No newline at end of file
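// A hypothetical end-to-end sketch of the operations above. The reduce on the
// grouped stream lives in GroupByDataStream, which this diff does not include.
object WordCountSketch {
  val lines = TextFile("/tmp/hamlet.txt")
  val words = lines flatMap { line => line.split(" ").iterator }
  val pairs = words map { word => (word, 1) }
  val byWord = pairs groupBy { pair => pair._1 } // grouped by word, ready for a reduce
}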
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala
import java.lang.annotation.Annotation
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.scala.analysis.UDF
import eu.stratosphere.scala.analysis.UDF1
import eu.stratosphere.scala.analysis.UDF2
import eu.stratosphere.scala.analysis.FieldSelector
import eu.stratosphere.pact.compiler.plan.OptimizerNode
import eu.stratosphere.pact.generic.contract.AbstractPact
import eu.stratosphere.scala.analysis.UDF0
import eu.stratosphere.pact.common.stubs.StubAnnotation
trait ScalaContract[T] { this: Contract =>
def getUDF(): UDF[T]
def getKeys: Seq[FieldSelector] = Seq()
def persistConfiguration(): Unit = {}
var persistHints: () => Unit = { () => }
def persistConfiguration(optimizerNode: Option[OptimizerNode]): Unit = {
ScalaContract.this match {
case contract: AbstractPact[_] => {
for ((key, inputNum) <- getKeys.zipWithIndex) {
val source = key.selectedFields.toSerializerIndexArray
val target = optimizerNode map { _.getRemappedKeys(inputNum) } getOrElse { contract.getKeyColumns(inputNum) }
assert(source.length == target.length, "Attempt to write " + source.length + " key indexes to an array of size " + target.length)
System.arraycopy(source, 0, target, 0, source.length)
}
}
case _ if getKeys.size > 0 => throw new UnsupportedOperationException("Attempted to set keys on a contract that doesn't support them")
case _ =>
}
persistHints()
persistConfiguration()
}
protected def annotations: Seq[Annotation] = Seq()
override def getUserCodeAnnotation[A <: Annotation](annotationClass: Class[A]): A = {
val res = annotations find { _.annotationType().equals(annotationClass) } map { _.asInstanceOf[A] } getOrElse null.asInstanceOf[A]
// println("returning ANOOT: " + res + " FOR: " + annotationClass.toString)
// res match {
// case r : StubAnnotation.ConstantFieldsFirst => println("CONSTANT FIELDS FIRST: " + r.value().mkString(","))
// case r : StubAnnotation.ConstantFieldsSecond => println("CONSTANT FIELDS SECOND: " + r.value().mkString(","))
// case _ =>
// }
res
}
}
trait NoOpScalaContract[In, Out] extends ScalaContract[Out] { this: Contract =>
}
trait UnionScalaContract[In] extends NoOpScalaContract[In, In] { this: Contract =>
override def getUDF(): UDF1[In, In]
}
trait HigherOrderScalaContract[T] extends ScalaContract[T] { this: Contract =>
override def getUDF(): UDF0[T]
}
trait BulkIterationScalaContract[T] extends HigherOrderScalaContract[T] { this: Contract =>
}
trait WorksetIterationScalaContract[T] extends HigherOrderScalaContract[T] { this: Contract =>
val key: FieldSelector
}
trait OneInputScalaContract[In, Out] extends ScalaContract[Out] { this: Contract =>
override def getUDF(): UDF1[In, Out]
}
trait TwoInputScalaContract[In1, In2, Out] extends ScalaContract[Out] { this: Contract =>
override def getUDF(): UDF2[In1, In2, Out]
}
trait OneInputKeyedScalaContract[In, Out] extends OneInputScalaContract[In, Out] { this: Contract =>
val key: FieldSelector
override def getKeys = Seq(key)
}
trait TwoInputKeyedScalaContract[LeftIn, RightIn, Out] extends TwoInputScalaContract[LeftIn, RightIn, Out] { this: Contract =>
val leftKey: FieldSelector
val rightKey: FieldSelector
override def getKeys = Seq(leftKey, rightKey)
}
\ No newline at end of file
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala
import scala.collection.JavaConversions.asJavaCollection
import eu.stratosphere.pact.common.plan.Plan
import eu.stratosphere.pact.compiler.plan.candidate.OptimizedPlan
import eu.stratosphere.pact.compiler.postpass.GenericPactRecordPostPass
import java.util.Calendar
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.scala.analysis.NewGlobalSchemaGenerator
import eu.stratosphere.scala.analysis.postPass.GlobalSchemaOptimizer
import eu.stratosphere.pact.common.plan.PlanAssembler
import eu.stratosphere.pact.common.plan.PlanAssemblerDescription
class ScalaPlan(scalaSinks: Seq[ScalaSink[_]], scalaJobName: String = "PACT SCALA Job at " + Calendar.getInstance().getTime()) extends Plan(asJavaCollection(scalaSinks map { _.sink }), scalaJobName) {
val pactSinks = scalaSinks map { _.sink.asInstanceOf[Contract with ScalaContract[_]] }
NewGlobalSchemaGenerator.initGlobalSchema(pactSinks)
override def getPostPassClassName() = "eu.stratosphere.scala.ScalaPostPass"
}
case class Args(argsMap: Map[String, String], defaultParallelism: Int, schemaHints: Boolean, schemaCompaction: Boolean) {
def apply(key: String): String = argsMap.getOrElse(key, key)
def apply(key: String, default: => String) = argsMap.getOrElse(key, default)
}
object Args {
def parse(args: Seq[String]): Args = {
var argsMap = Map[String, String]()
var defaultParallelism = 1
var schemaHints = true
var schemaCompaction = true
val ParamName = "-(.+)".r
def parse(args: Seq[String]): Unit = args match {
case Seq("-subtasks", value, rest @ _*) => { defaultParallelism = value.toInt; parse(rest) }
case Seq("-nohints", rest @ _*) => { schemaHints = false; parse(rest) }
case Seq("-nocompact", rest @ _*) => { schemaCompaction = false; parse(rest) }
case Seq(ParamName(name), value, rest @ _*) => { argsMap = argsMap.updated(name, value); parse(rest) }
case Seq() =>
}
parse(args)
Args(argsMap, defaultParallelism, schemaHints, schemaCompaction)
}
}
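// A hypothetical invocation of the parser above.
object ArgsSketch {
  val parsed = Args.parse(Seq("-subtasks", "4", "-nohints", "-input", "/tmp/in"))
  // parsed("input") == "/tmp/in"
  // parsed("missing") == "missing" -- unknown keys fall back to the key itself
  // parsed.defaultParallelism == 4, parsed.schemaHints == false
}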
abstract class ScalaPlanAssembler extends PlanAssembler {
def getScalaPlan(args: Args): ScalaPlan
override def getPlan(args: String*): Plan = {
val scalaArgs = Args.parse(args.toSeq)
getScalaPlan(scalaArgs)
}
}
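// A hypothetical concrete assembler. outputFormat is a placeholder for an actual
// DataSinkFormat[String] implementation, which this diff does not include.
class WordListAssembler extends ScalaPlanAssembler {
  def outputFormat: DataSinkFormat[String] = ???
  def getScalaPlan(args: Args): ScalaPlan = {
    val input = TextFile(args("input"))
    val sink = input.write(args("output"), outputFormat)
    new ScalaPlan(Seq(sink), "Word List")
  }
}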
class Pact4sInstantiationException(cause: Throwable) extends Exception("Could not instantiate program.", cause)
class ScalaPostPass extends GenericPactRecordPostPass with GlobalSchemaOptimizer {
override def postPass(plan: OptimizedPlan): Unit = {
optimizeSchema(plan, false)
super.postPass(plan)
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis
import scala.collection.JavaConversions.asScalaBuffer
import scala.collection.JavaConversions.collectionAsScalaIterable
import eu.stratosphere.scala.analysis._
import eu.stratosphere.scala.contracts._
import eu.stratosphere.pact.compiler.plan._
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.pact.common.contract.GenericDataSink
import eu.stratosphere.pact.generic.contract.DualInputContract
import eu.stratosphere.pact.generic.contract.SingleInputContract
import eu.stratosphere.pact.common.contract.MapContract
import eu.stratosphere.scala.ScalaContract
import eu.stratosphere.pact.common.contract.GenericDataSource
import eu.stratosphere.scala.OneInputScalaContract
import eu.stratosphere.scala.OneInputKeyedScalaContract
import eu.stratosphere.pact.common.contract.ReduceContract
import eu.stratosphere.pact.common.contract.CrossContract
import eu.stratosphere.scala.TwoInputScalaContract
import eu.stratosphere.pact.common.contract.MatchContract
import eu.stratosphere.scala.TwoInputKeyedScalaContract
import eu.stratosphere.pact.common.contract.CoGroupContract
import eu.stratosphere.pact.generic.contract.WorksetIteration
import eu.stratosphere.pact.generic.contract.BulkIteration
import eu.stratosphere.scala.WorksetIterationScalaContract
import eu.stratosphere.scala.BulkIterationScalaContract
import eu.stratosphere.scala.UnionScalaContract
object Extractors {
object DataSinkNode {
def unapply(node: Contract): Option[(UDF1[_, _], List[Contract])] = node match {
case contract: GenericDataSink with ScalaContract[_] => {
Some((contract.getUDF.asInstanceOf[UDF1[_, _]], node.asInstanceOf[GenericDataSink].getInputs().toList))
}
case _ => None
}
}
object DataSourceNode {
def unapply(node: Contract): Option[(UDF0[_])] = node match {
case contract: GenericDataSource[_] with ScalaContract[_] => Some(contract.getUDF.asInstanceOf[UDF0[_]])
case _ => None
}
}
object CoGroupNode {
def unapply(node: Contract): Option[(UDF2[_, _, _], FieldSelector, FieldSelector, List[Contract], List[Contract])] = node match {
case contract: CoGroupContract with TwoInputKeyedScalaContract[_, _, _] => Some((contract.getUDF, contract.leftKey, contract.rightKey, contract.asInstanceOf[DualInputContract[_]].getFirstInputs().toList, contract.asInstanceOf[DualInputContract[_]].getSecondInputs().toList))
case _ => None
}
}
object CrossNode {
def unapply(node: Contract): Option[(UDF2[_, _, _], List[Contract], List[Contract])] = node match {
case contract: CrossContract with TwoInputScalaContract[_, _, _] => Some((contract.getUDF, contract.asInstanceOf[DualInputContract[_]].getFirstInputs().toList, contract.asInstanceOf[DualInputContract[_]].getSecondInputs().toList))
case _ => None
}
}
object JoinNode {
def unapply(node: Contract): Option[(UDF2[_, _, _], FieldSelector, FieldSelector, List[Contract], List[Contract])] = node match {
case contract: MatchContract with TwoInputKeyedScalaContract[ _, _, _] => Some((contract.getUDF, contract.leftKey, contract.rightKey, contract.asInstanceOf[DualInputContract[_]].getFirstInputs().toList, contract.asInstanceOf[DualInputContract[_]].getSecondInputs().toList))
case _ => None
}
}
object MapNode {
def unapply(node: Contract): Option[(UDF1[_, _], List[Contract])] = node match {
case contract: MapContract with OneInputScalaContract[_, _] => Some((contract.getUDF, contract.asInstanceOf[SingleInputContract[_]].getInputs().toList))
case _ => None
}
}
object UnionNode {
def unapply(node: Contract): Option[(UDF1[_, _], List[Contract])] = node match {
case contract: MapContract with UnionScalaContract[_] => Some((contract.getUDF, contract.asInstanceOf[SingleInputContract[_]].getInputs().toList))
case _ => None
}
}
object ReduceNode {
def unapply(node: Contract): Option[(UDF1[_, _], FieldSelector, List[Contract])] = node match {
case contract: ReduceContract with OneInputKeyedScalaContract[_, _] => Some((contract.getUDF, contract.key, contract.asInstanceOf[SingleInputContract[_]].getInputs().toList))
case _ => None
}
}
object WorksetIterationNode {
def unapply(node: Contract): Option[(UDF0[_], FieldSelector, List[Contract], List[Contract])] = node match {
case contract: WorksetIteration with WorksetIterationScalaContract[_] => Some((contract.getUDF, contract.key, contract.asInstanceOf[DualInputContract[_]].getFirstInputs().toList, contract.asInstanceOf[DualInputContract[_]].getSecondInputs().toList))
case _ => None
}
}
object BulkIterationNode {
def unapply(node: Contract): Option[(UDF0[_], List[Contract])] = node match {
case contract: BulkIteration with BulkIterationScalaContract[_] => Some((contract.getUDF, contract.asInstanceOf[SingleInputContract[_]].getInputs().toList))
case _ => None
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis
import scala.language.experimental.macros
import scala.language.implicitConversions
/**
* Instances of this class are typically created by the field selector macros
* `fieldSelectorImpl` and `keySelectorImpl` in [[FieldSelectorMacros]].
*
* In addition to the language restrictions applied to the lambda expression, the selected fields
* must also be top-level. Nested fields (such as a list element or an inner instance of a
* recursive type) are not allowed.
*
* @param selection The selected fields
*/
class FieldSelector(udt: UDT[_], selection: List[Int]) extends Serializable {
val inputFields = FieldSet.newInputSet(udt.numFields)
val selectionIndices = udt.getSelectionIndices(selection)
val selectedFields = inputFields.select(selectionIndices)
for (field <- inputFields.diff(selectedFields))
field.isUsed = false
def copy() = new FieldSelector(udt, selection)
}
\ No newline at end of file
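// A sketch, assuming a UDT for a two-field record type such as (String, Int):
// selecting field 0 leaves field 1 of the input set marked isUsed = false.
object FieldSelectorSketch {
  def firstField(udt: UDT[(String, Int)]): FieldSelector = new FieldSelector(udt, List(0))
}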
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis
import scala.reflect.ClassTag
abstract sealed class Field extends Serializable {
val localPos: Int
val globalPos: GlobalPos
var isUsed: Boolean = true
}
case class InputField(localPos: Int, globalPos: GlobalPos = new GlobalPos) extends Field
case class OutputField(localPos: Int, globalPos: GlobalPos = new GlobalPos) extends Field
class FieldSet[+FieldType <: Field] private (private val fields: Seq[FieldType]) extends Serializable {
private var globalized: Boolean = false
def isGlobalized = globalized
def setGlobalized(): Unit = {
assert(!globalized, "Field set has already been globalized")
globalized = true
}
def apply(localPos: Int): FieldType = {
fields.find(_.localPos == localPos).get
}
def select(selection: Seq[Int]): FieldSet[FieldType] = {
val outer = this
new FieldSet[FieldType](selection map apply) {
override def setGlobalized() = outer.setGlobalized()
override def isGlobalized = outer.isGlobalized
}
}
def toSerializerIndexArray: Array[Int] = fields map {
case field if field.isUsed => field.globalPos.getValue
case _ => -1
} toArray
def toIndexSet: Set[Int] = fields.filter(_.isUsed).map(_.globalPos.getValue).toSet
def mapToArray[T: ClassTag](fun: FieldType => T): Array[T] = {
(fields map fun) toArray
}
}
object FieldSet {
def newInputSet(numFields: Int): FieldSet[InputField] = new FieldSet((0 until numFields) map { InputField(_) })
def newOutputSet(numFields: Int): FieldSet[OutputField] = new FieldSet((0 until numFields) map { OutputField(_) })
def newInputSet[T](udt: UDT[T]): FieldSet[InputField] = newInputSet(udt.numFields)
def newOutputSet[T](udt: UDT[T]): FieldSet[OutputField] = newOutputSet(udt.numFields)
implicit def toSeq[FieldType <: Field](fieldSet: FieldSet[FieldType]): Seq[FieldType] = fieldSet.fields
}
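// A small illustration of the classes above: three input fields, two of which
// receive concrete global positions.
object FieldSetSketch {
  val in = FieldSet.newInputSet(3)
  in(0).globalPos.setIndex(5)
  in(1).globalPos.setIndex(6)
  val selected = in.select(Seq(0, 1))
  // selected.toSerializerIndexArray == Array(5, 6)
  // in.toSerializerIndexArray == Array(5, 6, -1) -- field 2 is still unresolved
}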
class GlobalPos extends Serializable {
private var pos: Either[Int, GlobalPos] = null
def getValue: Int = pos match {
case null => -1
case Left(index) => index
case Right(target) => target.getValue
}
def getIndex: Option[Int] = pos match {
case null | Right(_) => None
case Left(index) => Some(index)
}
def getReference: Option[GlobalPos] = pos match {
case null | Left(_) => None
case Right(target) => Some(target)
}
def resolve: GlobalPos = pos match {
case null => this
case Left(_) => this
case Right(target) => target.resolve
}
def isUnknown = pos == null
def isIndex = (pos != null) && pos.isLeft
def isReference = (pos != null) && pos.isRight
def setIndex(index: Int) = {
assert(pos == null || pos.isLeft, "Cannot convert a position reference to an index")
pos = Left(index)
}
def setReference(target: GlobalPos) = {
// assert(pos == null, "Cannot overwrite a known position with a reference")
pos = Right(target)
}
}
object GlobalPos {
object Unknown {
def unapply(globalPos: GlobalPos): Boolean = globalPos.isUnknown
}
object Index {
def unapply(globalPos: GlobalPos): Option[Int] = globalPos.getIndex
}
object Reference {
def unapply(globalPos: GlobalPos): Option[GlobalPos] = globalPos.getReference
}
}
\ No newline at end of file
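// Positions resolve transitively through reference chains, for example:
object GlobalPosSketch {
  val a = new GlobalPos
  val b = new GlobalPos
  a.setIndex(3)
  b.setReference(a)
  // b.getValue == 3, b.getIndex == None, (b.resolve eq a) == true
}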
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis
import java.util.{List => JList}
import scala.collection.JavaConversions.asScalaBuffer
import scala.collection.JavaConversions.bufferAsJavaList
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.pact.generic.contract.DualInputContract
import eu.stratosphere.pact.generic.contract.SingleInputContract
import eu.stratosphere.scala.analysis.FieldSet.toSeq
import eu.stratosphere.scala.ScalaContract
import eu.stratosphere.pact.common.contract.FileDataSink
import eu.stratosphere.pact.common.contract.GenericDataSource
import eu.stratosphere.pact.common.contract.MapContract
import eu.stratosphere.scala.OneInputScalaContract
import eu.stratosphere.pact.common.contract.ReduceContract
import eu.stratosphere.scala.OneInputKeyedScalaContract
import eu.stratosphere.pact.common.contract.CrossContract
import eu.stratosphere.scala.TwoInputScalaContract
import eu.stratosphere.pact.common.contract.MatchContract
import eu.stratosphere.scala.TwoInputKeyedScalaContract
import eu.stratosphere.pact.common.stubs.CoGroupStub
import eu.stratosphere.pact.common.contract.CoGroupContract
import eu.stratosphere.scala.UnionScalaContract
import eu.stratosphere.scala.operators.CopyOperator
import eu.stratosphere.pact.generic.contract.BulkIteration
import eu.stratosphere.scala.BulkIterationScalaContract
import eu.stratosphere.pact.common.contract.GenericDataSink
object GlobalSchemaGenerator {
def initGlobalSchema(sinks: Seq[Contract with ScalaContract[_]]): Unit = {
sinks.foldLeft(0) { (freePos, contract) => globalizeContract(contract, Seq(), Map(), None, freePos) }
}
/**
* Computes disjoint write sets for a contract and its inputs.
*
* @param contract The contract to globalize
* @param parentInputs Input fields which should be bound to the contract's outputs
* @param proxies Provides contracts for iteration placeholders
* @param fixedOutputs Specifies required positions for the contract's output fields, or None to allocate new positions
* @param freePos The current first available position in the global schema
* @return The new first available position in the global schema
*/
private def globalizeContract(contract: Contract, parentInputs: Seq[FieldSet[InputField]], proxies: Map[Contract, Contract with ScalaContract[_]], fixedOutputs: Option[FieldSet[Field]], freePos: Int): Int = {
val contract4s = proxies.getOrElse(contract, contract.asInstanceOf[Contract with ScalaContract[_]])
parentInputs.foreach(contract4s.getUDF.attachOutputsToInputs)
contract4s.getUDF.outputFields.isGlobalized match {
case true => freePos
case false => {
val freePos1 = globalizeContract(contract4s, proxies, fixedOutputs, freePos)
eliminateNoOps(contract4s)
contract4s.persistConfiguration(None)
freePos1
}
}
}
private def globalizeContract(contract: Contract with ScalaContract[_], proxies: Map[Contract, Contract with ScalaContract[_]], fixedOutputs: Option[FieldSet[Field]], freePos: Int): Int = {
contract match {
case contract : FileDataSink with ScalaContract[_] => {
contract.getUDF.outputFields.setGlobalized()
globalizeContract(contract.getInputs().get(0), Seq(contract.getUDF.asInstanceOf[UDF1[_,_]].inputFields), proxies, None, freePos)
}
case contract: GenericDataSource[_] with ScalaContract[_] => {
contract.getUDF.setOutputGlobalIndexes(freePos, fixedOutputs)
}
case contract : BulkIteration with BulkIterationScalaContract[_] => {
val s0contract = proxies.getOrElse(contract.getInputs().get(0), contract.getInputs().get(0).asInstanceOf[Contract with ScalaContract[_]])
val newProxies = proxies + (contract.getPartialSolution() -> s0contract)
val freePos1 = globalizeContract(contract.getInputs().get(0), Seq(), proxies, fixedOutputs, freePos)
val freePos2 = globalizeContract(contract.getNextPartialSolution(), Seq(), newProxies, Some(s0contract.getUDF.outputFields), freePos1)
val freePos3 = Option(contract.getTerminationCriterion()) map { globalizeContract(_, Seq(), newProxies, None, freePos2) } getOrElse freePos2
contract.getUDF.assignOutputGlobalIndexes(s0contract.getUDF.outputFields)
freePos3
}
// case contract @ WorksetIterate4sContract(s0, ws0, deltaS, newWS, placeholderS, placeholderWS) => {
//
// val s0contract = proxies.getOrElse(s0.get(0), s0.asInstanceOf[Contract with ScalaContract[_]])
// val ws0contract = proxies.getOrElse(ws0.get(0), ws0.asInstanceOf[Contract with ScalaContract[_]])
// val newProxies = proxies + (placeholderS -> s0contract) + (placeholderWS -> ws0contract)
//
// val freePos1 = globalizeContract(s0.get(0), Seq(contract.key.inputFields), proxies, fixedOutputs, freePos)
// val freePos2 = globalizeContract(ws0.get(0), Seq(), proxies, None, freePos1)
// val freePos3 = globalizeContract(deltaS, Seq(), newProxies, Some(s0contract.getUDF.outputFields), freePos2)
// val freePos4 = globalizeContract(newWS, Seq(), newProxies, Some(ws0contract.getUDF.outputFields), freePos3)
//
// contract.udf.assignOutputGlobalIndexes(s0contract.getUDF.outputFields)
//
// freePos4
// }
case contract : CoGroupContract with TwoInputKeyedScalaContract[_, _, _] => {
val freePos1 = globalizeContract(contract.getFirstInputs().get(0), Seq(contract.getUDF.leftInputFields, contract.leftKey.inputFields), proxies, None, freePos)
val freePos2 = globalizeContract(contract.getSecondInputs().get(0), Seq(contract.getUDF.rightInputFields, contract.rightKey.inputFields), proxies, None, freePos1)
contract.getUDF.setOutputGlobalIndexes(freePos2, fixedOutputs)
}
case contract: CrossContract with TwoInputScalaContract[_, _, _] => {
val freePos1 = globalizeContract(contract.getFirstInputs().get(0), Seq(contract.getUDF.leftInputFields), proxies, None, freePos)
val freePos2 = globalizeContract(contract.getSecondInputs().get(0), Seq(contract.getUDF.rightInputFields), proxies, None, freePos1)
contract.getUDF.setOutputGlobalIndexes(freePos2, fixedOutputs)
}
case contract : MatchContract with TwoInputKeyedScalaContract[_, _, _] => {
val freePos1 = globalizeContract(contract.getFirstInputs().get(0), Seq(contract.getUDF.leftInputFields, contract.leftKey.inputFields), proxies, None, freePos)
val freePos2 = globalizeContract(contract.getSecondInputs().get(0), Seq(contract.getUDF.rightInputFields, contract.rightKey.inputFields), proxies, None, freePos1)
contract.getUDF.setOutputGlobalIndexes(freePos2, fixedOutputs)
}
case contract : MapContract with OneInputScalaContract[_, _] => {
val freePos1 = globalizeContract(contract.getInputs().get(0), Seq(contract.getUDF.inputFields), proxies, None, freePos)
contract.getUDF.setOutputGlobalIndexes(freePos1, fixedOutputs)
}
case contract : ReduceContract with OneInputKeyedScalaContract[_, _] => {
val freePos1 = globalizeContract(contract.getInputs().get(0), Seq(contract.getUDF.inputFields, contract.key.inputFields), proxies, None, freePos)
contract.getUDF.setOutputGlobalIndexes(freePos1, fixedOutputs)
}
case contract : MapContract with UnionScalaContract[_] => {
// Determine where this contract's children should write their output
val freePos1 = contract.getUDF.setOutputGlobalIndexes(freePos, fixedOutputs)
val inputs = contract.getInputs()
// If an input hasn't yet allocated its output fields, then we can force them into
// the expected position. Otherwise, the output fields must be physically copied.
for (idx <- 0 until inputs.size()) {
val input = inputs.get(idx)
val input4s = proxies.getOrElse(input, input.asInstanceOf[Contract with ScalaContract[_]])
if (input4s.getUDF.outputFields.isGlobalized || input4s.getUDF.outputFields.exists(_.globalPos.isReference)) {
// inputs.set(idx, CopyOperator(input4s))
throw new RuntimeException("Copy operator needed, not yet implemented properly.")
}
}
inputs.foldLeft(freePos1) { (freePos2, input) =>
globalizeContract(input, Seq(), proxies, Some(contract.getUDF.outputFields), freePos2)
}
}
}
}
private def eliminateNoOps(contract: Contract): Unit = {
def elim(children: JList[Contract]): Unit = {
val newChildren = children flatMap {
case c: MapContract with UnionScalaContract[_] => c.getInputs()
case child => List(child)
}
children.clear()
children.addAll(newChildren)
}
contract match {
case c: SingleInputContract[_] => elim(c.getInputs())
case c: DualInputContract[_] => elim(c.getFirstInputs()); elim(c.getSecondInputs())
case c: GenericDataSink => elim(c.getInputs())
case _ =>
}
}
}
\ No newline at end of file
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis
import scala.Array.canBuildFrom
import scala.collection.JavaConversions.asScalaBuffer
import scala.collection.JavaConversions.collectionAsScalaIterable
import Extractors.CoGroupNode
import Extractors.CrossNode
import Extractors.DataSinkNode
import Extractors.DataSourceNode
import Extractors.JoinNode
import Extractors.MapNode
import Extractors.ReduceNode
import eu.stratosphere.pact.common.contract.GenericDataSink
import eu.stratosphere.pact.common.plan.Plan
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.pact.generic.contract.DualInputContract
import eu.stratosphere.pact.generic.contract.SingleInputContract
import eu.stratosphere.pact.generic.contract.BulkIteration
import eu.stratosphere.pact.generic.contract.WorksetIteration
object GlobalSchemaPrinter {
import Extractors._
def printSchema(plan: Plan): Unit = {
println("### " + plan.getJobName + " ###")
plan.getDataSinks.foldLeft(Set[Contract]())(printSchema)
println("####" + ("#" * plan.getJobName.length) + "####")
println()
}
private def printSchema(visited: Set[Contract], node: Contract): Set[Contract] = {
visited.contains(node) match {
case true => visited
case false => {
val children = node match {
case bi: BulkIteration => bi.getInputs().toList :+ bi.getNextPartialSolution()
case wi: WorksetIteration => wi.getInitialSolutionSet().toList ++ wi.getInitialWorkset().toList :+ wi.getSolutionSetDelta() :+ wi.getNextWorkset()
case si : SingleInputContract[_] => si.getInputs().toList
case di : DualInputContract[_] => di.getFirstInputs().toList ++ di.getSecondInputs().toList
case gds : GenericDataSink => gds.getInputs().toList
case _ => List()
}
val newVisited = children.foldLeft(visited + node)(printSchema)
node match {
case _ : BulkIteration.PartialSolutionPlaceHolder =>
case _ : WorksetIteration.SolutionSetPlaceHolder =>
case _ : WorksetIteration.WorksetPlaceHolder =>
case DataSinkNode(udf, input) => {
printInfo(node, "Sink",
Seq(),
Seq(("", udf.inputFields)),
Seq(("", udf.getForwardIndexArray)),
Seq(("", udf.getDiscardIndexArray)),
udf.outputFields
)
}
case DataSourceNode(udf) => {
printInfo(node, "Source",
Seq(),
Seq(),
Seq(),
Seq(),
udf.outputFields
)
}
case CoGroupNode(udf, leftKey, rightKey, leftInput, rightInput) => {
printInfo(node, "CoGroup",
Seq(("L", leftKey), ("R", rightKey)),
Seq(("L", udf.leftInputFields), ("R", udf.rightInputFields)),
Seq(("L", udf.getLeftForwardIndexArray), ("R", udf.getRightForwardIndexArray)),
Seq(("L", udf.getLeftDiscardIndexArray), ("R", udf.getRightDiscardIndexArray)),
udf.outputFields
)
}
case CrossNode(udf, leftInput, rightInput) => {
printInfo(node, "Cross",
Seq(),
Seq(("L", udf.leftInputFields), ("R", udf.rightInputFields)),
Seq(("L", udf.getLeftForwardIndexArray), ("R", udf.getRightForwardIndexArray)),
Seq(("L", udf.getLeftDiscardIndexArray), ("R", udf.getRightDiscardIndexArray)),
udf.outputFields
)
}
case JoinNode(udf, leftKey, rightKey, leftInput, rightInput) => {
printInfo(node, "Join",
Seq(("L", leftKey), ("R", rightKey)),
Seq(("L", udf.leftInputFields), ("R", udf.rightInputFields)),
Seq(("L", udf.getLeftForwardIndexArray), ("R", udf.getRightForwardIndexArray)),
Seq(("L", udf.getLeftDiscardIndexArray), ("R", udf.getRightDiscardIndexArray)),
udf.outputFields
)
}
case MapNode(udf, input) => {
printInfo(node, "Map",
Seq(),
Seq(("", udf.inputFields)),
Seq(("", udf.getForwardIndexArray)),
Seq(("", udf.getDiscardIndexArray)),
udf.outputFields
)
}
case UnionNode(udf, input) => {
printInfo(node, "Union",
Seq(),
Seq(("", udf.inputFields)),
Seq(("", udf.getForwardIndexArray)),
Seq(("", udf.getDiscardIndexArray)),
udf.outputFields
)
}
case ReduceNode(udf, key, input) => {
// val contract = node.asInstanceOf[Reduce4sContract[_, _, _]]
// contract.userCombineCode map { _ =>
// printInfo(node, "Combine",
// Seq(("", key)),
// Seq(("", udf.inputFields)),
// Seq(("", contract.combineForwardSet.toArray)),
// Seq(("", contract.combineDiscardSet.toArray)),
// udf.inputFields
// )
// }
printInfo(node, "Reduce",
Seq(("", key)),
Seq(("", udf.inputFields)),
Seq(("", udf.getForwardIndexArray)),
Seq(("", udf.getDiscardIndexArray)),
udf.outputFields
)
}
case WorksetIterationNode(udf, key, input1, input2) => {
printInfo(node, "WorksetIterate",
Seq(("", key)),
Seq(),
Seq(),
Seq(),
udf.outputFields)
}
case BulkIterationNode(udf, input1) => {
printInfo(node, "BulkIterate",
Seq(),
Seq(),
Seq(),
Seq(),
udf.outputFields)
}
}
newVisited
}
}
}
private def printInfo(node: Contract, kind: String, keys: Seq[(String, FieldSelector)], reads: Seq[(String, FieldSet[_])], forwards: Seq[(String, Array[Int])], discards: Seq[(String, Array[Int])], writes: FieldSet[_]): Unit = {
def indexesToStrings(pre: String, indexes: Array[Int]) = indexes map {
case -1 => "_"
case i => pre + i
}
val formatString = "%s (%s): K{%s}: R[%s] => F[%s] - D[%s] + W[%s]"
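// Illustrative reading only (hypothetical values): a Join node might print as
//   "myJoin (Join): K{L0, R0}: R[L1, R2] => F[L2] - D[R3] + W[4]"
// i.e. key fields, read fields, forwarded fields, discarded fields, written fields.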
val name = node.getName
val sKeys = keys flatMap { case (pre, value) => value.selectedFields.toSerializerIndexArray.map(pre + _) } mkString ", "
val sReads = reads flatMap { case (pre, value) => indexesToStrings(pre, value.toSerializerIndexArray) } mkString ", "
val sForwards = forwards flatMap { case (pre, value) => value.sorted.map(pre + _) } mkString ", "
val sDiscards = discards flatMap { case (pre, value) => value.sorted.map(pre + _) } mkString ", "
val sWrites = indexesToStrings("", writes.toSerializerIndexArray) mkString ", "
println(formatString.format(name, kind, sKeys, sReads, sForwards, sDiscards, sWrites))
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis
import java.util.{List => JList}
import scala.collection.JavaConversions.asScalaBuffer
import scala.collection.JavaConversions.bufferAsJavaList
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.pact.generic.contract.DualInputContract
import eu.stratosphere.pact.generic.contract.SingleInputContract
import eu.stratosphere.scala.analysis.FieldSet.toSeq
import eu.stratosphere.scala.ScalaContract
import eu.stratosphere.pact.common.contract.FileDataSink
import eu.stratosphere.pact.common.contract.GenericDataSource
import eu.stratosphere.pact.common.contract.MapContract
import eu.stratosphere.scala.OneInputScalaContract
import eu.stratosphere.pact.common.contract.ReduceContract
import eu.stratosphere.scala.OneInputKeyedScalaContract
import eu.stratosphere.pact.common.contract.CrossContract
import eu.stratosphere.scala.TwoInputScalaContract
import eu.stratosphere.pact.common.contract.MatchContract
import eu.stratosphere.scala.TwoInputKeyedScalaContract
import eu.stratosphere.pact.common.stubs.CoGroupStub
import eu.stratosphere.pact.common.contract.CoGroupContract
import eu.stratosphere.scala.UnionScalaContract
import eu.stratosphere.scala.operators.CopyOperator
import eu.stratosphere.pact.generic.contract.BulkIteration
import eu.stratosphere.scala.BulkIterationScalaContract
import eu.stratosphere.pact.common.contract.GenericDataSink
import eu.stratosphere.scala.WorksetIterationScalaContract
import eu.stratosphere.pact.generic.contract.WorksetIteration
object NewGlobalSchemaGenerator {
def initGlobalSchema(sinks: Seq[Contract with ScalaContract[_]]): Unit = {
sinks.foldLeft(0) { (freePos, contract) => globalizeContract(contract, Seq(), Map(), None, freePos) }
}
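// Minimal usage sketch (hypothetical wiring): calling
//   NewGlobalSchemaGenerator.initGlobalSchema(plan.getDataSinks.map(...))
// once per plan walks the DAG from the sinks and assigns every UDF's output fields
// disjoint global record positions, starting at position 0.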
/**
* Computes disjoint write sets for a contract and its inputs.
*
* @param contract The contract to globalize
* @param parentInputs Input fields which should be bound to the contract's outputs
* @param proxies Provides contracts for iteration placeholders
* @param fixedOutputs Specifies required positions for the contract's output fields, or None to allocate new positions
* @param freePos The current first available position in the global schema
* @return The new first available position in the global schema
*/
private def globalizeContract(contract: Contract, parentInputs: Seq[FieldSet[InputField]], proxies: Map[Contract, Contract with ScalaContract[_]], fixedOutputs: Option[FieldSet[Field]], freePos: Int): Int = {
val contract4s = proxies.getOrElse(contract, contract.asInstanceOf[Contract with ScalaContract[_]])
parentInputs.foreach(contract4s.getUDF.attachOutputsToInputs)
contract4s.getUDF.outputFields.isGlobalized match {
case true => freePos
case false => {
val freePos1 = globalizeContract(contract4s, proxies, fixedOutputs, freePos)
eliminateNoOps(contract4s)
contract4s.persistConfiguration(None)
freePos1
}
}
}
private def globalizeContract(contract: Contract with ScalaContract[_], proxies: Map[Contract, Contract with ScalaContract[_]], fixedOutputs: Option[FieldSet[Field]], freePos: Int): Int = {
contract match {
case contract : FileDataSink with ScalaContract[_] => {
contract.getUDF.outputFields.setGlobalized()
globalizeContract(contract.getInputs().get(0), Seq(contract.getUDF.asInstanceOf[UDF1[_,_]].inputFields), proxies, None, freePos)
}
case contract: GenericDataSource[_] with ScalaContract[_] => {
contract.getUDF.setOutputGlobalIndexes(freePos, fixedOutputs)
}
case contract : BulkIteration with BulkIterationScalaContract[_] => {
val s0contract = proxies.getOrElse(contract.getInputs().get(0), contract.getInputs().get(0).asInstanceOf[Contract with ScalaContract[_]])
val newProxies = proxies + (contract.getPartialSolution() -> s0contract)
val freePos1 = globalizeContract(contract.getInputs().get(0), Seq(), proxies, fixedOutputs, freePos)
val freePos2 = globalizeContract(contract.getNextPartialSolution(), Seq(), newProxies, Some(s0contract.getUDF.outputFields), freePos1)
val freePos3 = Option(contract.getTerminationCriterion()) map { globalizeContract(_, Seq(), newProxies, None, freePos2) } getOrElse freePos2
contract.getUDF.assignOutputGlobalIndexes(s0contract.getUDF.outputFields)
freePos3
}
case contract : WorksetIteration with WorksetIterationScalaContract[_] => {
val s0 = contract.getInitialSolutionSet().get(0)
val ws0 = contract.getInitialWorkset().get(0)
val deltaS = contract.getSolutionSetDelta()
val newWS = contract.getNextWorkset()
val s0contract = proxies.getOrElse(s0, s0.asInstanceOf[Contract with ScalaContract[_]])
val ws0contract = proxies.getOrElse(ws0, ws0.asInstanceOf[Contract with ScalaContract[_]])
val newProxies = proxies + (contract.getSolutionSet() -> s0contract) + (contract.getWorkset() -> ws0contract)
val freePos1 = globalizeContract(s0, Seq(contract.key.inputFields), proxies, fixedOutputs, freePos)
val freePos2 = globalizeContract(ws0, Seq(), proxies, None, freePos1)
val freePos3 = globalizeContract(deltaS, Seq(), newProxies, Some(s0contract.getUDF.outputFields), freePos2)
val freePos4 = globalizeContract(newWS, Seq(), newProxies, Some(ws0contract.getUDF.outputFields), freePos3)
contract.getUDF.assignOutputGlobalIndexes(s0contract.getUDF.outputFields)
freePos4
}
case contract : CoGroupContract with TwoInputKeyedScalaContract[_, _, _] => {
val freePos1 = globalizeContract(contract.getFirstInputs().get(0), Seq(contract.getUDF.leftInputFields, contract.leftKey.inputFields), proxies, None, freePos)
val freePos2 = globalizeContract(contract.getSecondInputs().get(0), Seq(contract.getUDF.rightInputFields, contract.rightKey.inputFields), proxies, None, freePos1)
contract.getUDF.setOutputGlobalIndexes(freePos2, fixedOutputs)
}
case contract: CrossContract with TwoInputScalaContract[_, _, _] => {
val freePos1 = globalizeContract(contract.getFirstInputs().get(0), Seq(contract.getUDF.leftInputFields), proxies, None, freePos)
val freePos2 = globalizeContract(contract.getSecondInputs().get(0), Seq(contract.getUDF.rightInputFields), proxies, None, freePos1)
contract.getUDF.setOutputGlobalIndexes(freePos2, fixedOutputs)
}
case contract : MatchContract with TwoInputKeyedScalaContract[_, _, _] => {
val freePos1 = globalizeContract(contract.getFirstInputs().get(0), Seq(contract.getUDF.leftInputFields, contract.leftKey.inputFields), proxies, None, freePos)
val freePos2 = globalizeContract(contract.getSecondInputs().get(0), Seq(contract.getUDF.rightInputFields, contract.rightKey.inputFields), proxies, None, freePos1)
contract.getUDF.setOutputGlobalIndexes(freePos2, fixedOutputs)
}
case contract : MapContract with OneInputScalaContract[_, _] => {
val freePos1 = globalizeContract(contract.getInputs().get(0), Seq(contract.getUDF.inputFields), proxies, None, freePos)
contract.getUDF.setOutputGlobalIndexes(freePos1, fixedOutputs)
}
case contract : ReduceContract with OneInputKeyedScalaContract[_, _] => {
val freePos1 = globalizeContract(contract.getInputs().get(0), Seq(contract.getUDF.inputFields, contract.key.inputFields), proxies, None, freePos)
contract.getUDF.setOutputGlobalIndexes(freePos1, fixedOutputs)
}
case contract : MapContract with UnionScalaContract[_] => {
// Determine where this contract's children should write their output
val freePos1 = contract.getUDF.setOutputGlobalIndexes(freePos, fixedOutputs)
val inputs = contract.getInputs()
// If an input hasn't yet allocated its output fields, then we can force them into
// the expected position. Otherwise, the output fields must be physically copied.
for (idx <- 0 until inputs.size()) {
val input = inputs.get(idx)
val input4s = proxies.getOrElse(input, input.asInstanceOf[Contract with ScalaContract[_]])
// if (input4s.getUDF.outputFields.isGlobalized || input4s.getUDF.outputFields.exists(_.globalPos.isReference)) {
//// inputs.set(idx, CopyOperator(input4s))
//// throw new RuntimeException("Copy operator needed, not yet implemented properly.")
// println("Copy operator needed, not yet implemented properly.")
// println(input4s + " is globalized: " + input4s.getUDF.outputFields.isGlobalized)
// println(input4s + " has global field.isReference: " + input4s.getUDF.outputFields.exists(_.globalPos.isReference))
// }
}
inputs.foldLeft(freePos1) { (freePos2, input) =>
globalizeContract(input, Seq(), proxies, Some(contract.getUDF.outputFields), freePos2)
}
}
}
}
private def eliminateNoOps(contract: Contract): Unit = {
def elim(children: JList[Contract]): Unit = {
val newChildren = children flatMap {
case c: MapContract with UnionScalaContract[_] => c.getInputs()
case child => List(child)
}
children.clear()
children.addAll(newChildren)
}
contract match {
case c: SingleInputContract[_] => elim(c.getInputs())
case c: DualInputContract[_] => elim(c.getFirstInputs()); elim(c.getSecondInputs())
case c: GenericDataSink => elim(c.getInputs())
case _ =>
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis
import scala.collection.mutable
abstract class UDF[R] extends Serializable {
val outputUDT: UDT[R]
val outputFields = FieldSet.newOutputSet(outputUDT)
def getOutputSerializer = outputUDT.getSerializer(outputFields.toSerializerIndexArray)
def getOutputLength = {
val indexes = outputFields.toIndexSet
if (indexes.isEmpty) {
0
} else {
indexes.max + 1
}
}
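// E.g. used output indexes {0, 2, 5} give getOutputLength == 6 (max index + 1).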
def allocateOutputGlobalIndexes(startPos: Int): Int = {
outputFields.setGlobalized()
outputFields.map(_.globalPos).foldLeft(startPos) {
case (i, gPos @ GlobalPos.Unknown()) => gPos.setIndex(i); i + 1
case (i, _) => i + 1
}
}
def assignOutputGlobalIndexes(sameAs: FieldSet[Field]): Unit = {
outputFields.setGlobalized()
outputFields.foreach {
case OutputField(localPos, globalPos) => globalPos.setReference(sameAs(localPos).globalPos)
}
}
def setOutputGlobalIndexes(startPos: Int, sameAs: Option[FieldSet[Field]]): Int = sameAs match {
case None => allocateOutputGlobalIndexes(startPos)
case Some(sameAs) => assignOutputGlobalIndexes(sameAs); startPos
}
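// Sketch of the two modes, assuming three not-yet-assigned output fields and startPos = 5:
// allocation sets their global positions to 5, 6, 7 and returns 8, while assignment
// aliases each field to the matching position of `sameAs` and returns startPos unchanged.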
def attachOutputsToInputs(inputFields: FieldSet[InputField]): Unit = {
inputFields.setGlobalized()
inputFields.foreach {
case InputField(localPos, globalPos) => globalPos.setReference(outputFields(localPos).globalPos)
}
}
protected def markFieldCopied(inputGlobalPos: GlobalPos, outputLocalPos: Int): Unit = {
val outputField = outputFields(outputLocalPos)
outputField.globalPos.setReference(inputGlobalPos)
outputField.isUsed = false
}
}
class UDF0[R](val outputUDT: UDT[R]) extends UDF[R]
class UDF1[T, R](val inputUDT: UDT[T], val outputUDT: UDT[R]) extends UDF[R] {
val inputFields = FieldSet.newInputSet(inputUDT)
val forwardSet = mutable.Set[GlobalPos]()
val discardSet = mutable.Set[GlobalPos]()
def getInputDeserializer = inputUDT.getSerializer(inputFields.toSerializerIndexArray)
def getForwardIndexArray = forwardSet.map(_.getValue).toArray
def getDiscardIndexArray = discardSet.map(_.getValue).toArray
override def getOutputLength = {
val forwardMax = if (forwardSet.isEmpty) -1 else forwardSet.map(_.getValue).max
math.max(super.getOutputLength, forwardMax + 1)
}
def markInputFieldUnread(localPos: Int): Unit = {
inputFields(localPos).isUsed = false
}
def markFieldCopied(inputLocalPos: Int, outputLocalPos: Int): Unit = {
val inputGlobalPos = inputFields(inputLocalPos).globalPos
forwardSet.add(inputGlobalPos)
markFieldCopied(inputGlobalPos, outputLocalPos)
}
}
class UDF2[T1, T2, R](val leftInputUDT: UDT[T1], val rightInputUDT: UDT[T2], val outputUDT: UDT[R]) extends UDF[R] {
val leftInputFields = FieldSet.newInputSet(leftInputUDT)
val leftForwardSet = mutable.Set[GlobalPos]()
val leftDiscardSet = mutable.Set[GlobalPos]()
val rightInputFields = FieldSet.newInputSet(rightInputUDT)
val rightForwardSet = mutable.Set[GlobalPos]()
val rightDiscardSet = mutable.Set[GlobalPos]()
def getLeftInputDeserializer = leftInputUDT.getSerializer(leftInputFields.toSerializerIndexArray)
def getLeftForwardIndexArray = leftForwardSet.map(_.getValue).toArray
def getLeftDiscardIndexArray = leftDiscardSet.map(_.getValue).toArray
def getRightInputDeserializer = rightInputUDT.getSerializer(rightInputFields.toSerializerIndexArray)
def getRightForwardIndexArray = rightForwardSet.map(_.getValue).toArray
def getRightDiscardIndexArray = rightDiscardSet.map(_.getValue).toArray
override def getOutputLength = {
val leftForwardMax = if (leftForwardSet.isEmpty) -1 else leftForwardSet.map(_.getValue).max
val rightForwardMax = if (rightForwardSet.isEmpty) -1 else rightForwardSet.map(_.getValue).max
math.max(super.getOutputLength, math.max(leftForwardMax, rightForwardMax) + 1)
}
private def getInputField(localPos: Either[Int, Int]): InputField = localPos match {
case Left(pos) => leftInputFields(pos)
case Right(pos) => rightInputFields(pos)
}
def markInputFieldUnread(localPos: Either[Int, Int]): Unit = {
localPos.fold(leftInputFields(_), rightInputFields(_)).isUsed = false
}
def markFieldCopied(inputLocalPos: Either[Int, Int], outputLocalPos: Int): Unit = {
val (fields, forwardSet) = inputLocalPos.fold(_ => (leftInputFields, leftForwardSet), _ => (rightInputFields, rightForwardSet))
val inputGlobalPos = fields(inputLocalPos.merge).globalPos
forwardSet.add(inputGlobalPos)
markFieldCopied(inputGlobalPos, outputLocalPos)
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis
import scala.collection.GenTraversableOnce
import scala.collection.generic.CanBuildFrom
import eu.stratosphere.pact.common.`type`.{ Key => PactKey }
import eu.stratosphere.pact.common.`type`.{ Value => PactValue }
import eu.stratosphere.pact.common.`type`.base.PactList
import eu.stratosphere.pact.common.`type`.base.PactString
import eu.stratosphere.pact.common.`type`.PactRecord
abstract class UDT[T] extends Serializable {
protected def createSerializer(indexMap: Array[Int]): UDTSerializer[T]
val fieldTypes: Array[Class[_ <: eu.stratosphere.pact.common.`type`.Value]]
val udtIdMap: Map[Int, Int]
def numFields = fieldTypes.length
def getSelectionIndices(selection: List[Int]) = {
selection map { udtIdMap.getOrElse(_, -1) }
}
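// E.g. with udtIdMap = Map(7 -> 0, 8 -> 1), getSelectionIndices(List(8, 9)) == List(1, -1):
// field ids unknown to this UDT map to -1.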
def getKeySet(fields: Seq[Int]): Array[Class[_ <: PactKey]] = {
fields map { fieldNum => fieldTypes(fieldNum).asInstanceOf[Class[_ <: PactKey]] } toArray
}
def getSerializer(indexMap: Array[Int]): UDTSerializer[T] = {
val ser = createSerializer(indexMap)
ser
}
@transient private var defaultSerializer: UDTSerializer[T] = null
def getSerializerWithDefaultLayout: UDTSerializer[T] = {
// This method will be reentrant if T is a recursive type
if (defaultSerializer == null) {
defaultSerializer = createSerializer((0 until numFields) toArray)
}
defaultSerializer
}
}
abstract class UDTSerializer[T](val indexMap: Array[Int]) {
def serialize(item: T, record: PactRecord)
def deserializeRecyclingOff(record: PactRecord): T
def deserializeRecyclingOn(record: PactRecord): T
}
trait UDTLowPriorityImplicits {
class UDTAnalysisFailedException extends RuntimeException("UDT analysis failed. This should have been caught at compile time.")
class UDTSelectionFailedException(msg: String) extends RuntimeException(msg)
implicit def unanalyzedUDT[T]: UDT[T] = throw new UDTAnalysisFailedException
}
object UDT extends UDTLowPriorityImplicits {
// UDTs needed by library code
object NothingUDT extends UDT[Nothing] {
override val fieldTypes = Array[Class[_ <: PactValue]]()
override val udtIdMap: Map[Int, Int] = Map()
override def createSerializer(indexMap: Array[Int]) = throw new UnsupportedOperationException("Cannot create UDTSerializer for type Nothing")
}
object StringUDT extends UDT[String] {
override val fieldTypes = Array[Class[_ <: PactValue]](classOf[PactString])
override val udtIdMap: Map[Int, Int] = Map()
override def createSerializer(indexMap: Array[Int]) = new UDTSerializer[String](indexMap) {
private val index = indexMap(0)
@transient private var pactField = new PactString()
// override def getFieldIndex(selection: Seq[String]): List[Int] = selection match {
// case Seq() => List(index)
//// case _ => invalidSelection(selection)
// case _ => throw new RuntimeException("Invalid selection: " + selection)
// }
override def serialize(item: String, record: PactRecord) = {
if (index >= 0) {
pactField.setValue(item)
record.setField(index, pactField)
}
}
override def deserializeRecyclingOff(record: PactRecord): String = {
if (index >= 0) {
record.getFieldInto(index, pactField)
pactField.getValue()
} else {
null
}
}
override def deserializeRecyclingOn(record: PactRecord): String = {
if (index >= 0) {
record.getFieldInto(index, pactField)
pactField.getValue()
} else {
null
}
}
private def readObject(in: java.io.ObjectInputStream) = {
in.defaultReadObject()
pactField = new PactString()
}
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis.postPass
import scala.collection.mutable
import scala.collection.JavaConversions._
import eu.stratosphere.scala.analysis._
import eu.stratosphere.scala.contracts._
import eu.stratosphere.pact.compiler.plan._
import eu.stratosphere.pact.compiler.plan.candidate.OptimizedPlan
object AmbientFieldDetector {
import Extractors._
import EdgeDependencySets.EdgeDependencySet
def updateAmbientFields(plan: OptimizedPlan, edgeDependencies: Map[PactConnection, EdgeDependencySet], outputPositions: Map[Int, GlobalPos]): Unit = {
plan.getDataSinks.map(_.getSinkNode).foldLeft(Set[OptimizerNode]())(updateAmbientFields(outputPositions, edgeDependencies))
}
private def updateAmbientFields(outputPositions: Map[Int, GlobalPos], edgeDependencies: Map[PactConnection, EdgeDependencySet])(visited: Set[OptimizerNode], node: OptimizerNode): Set[OptimizerNode] = {
visited.contains(node) match {
case true => visited
case false => {
node match {
case _: SinkJoiner | _: BinaryUnionNode =>
case DataSinkNode(udf, input) =>
case DataSourceNode(udf) =>
case CoGroupNode(udf, _, _, leftInput, rightInput) => {
val leftProvides = edgeDependencies(leftInput).childProvides
val rightProvides = edgeDependencies(rightInput).childProvides
val parentNeeds = edgeDependencies(node.getOutgoingConnections.head).childProvides
val writes = udf.outputFields.toIndexSet
populateSets(udf.leftForwardSet, udf.leftDiscardSet, leftProvides, parentNeeds, writes, outputPositions)
populateSets(udf.rightForwardSet, udf.rightDiscardSet, rightProvides, parentNeeds, writes, outputPositions)
}
case CrossNode(udf, leftInput, rightInput) => {
val leftProvides = edgeDependencies(leftInput).childProvides
val rightProvides = edgeDependencies(rightInput).childProvides
val parentNeeds = edgeDependencies(node.getOutgoingConnections.head).childProvides
val writes = udf.outputFields.toIndexSet
populateSets(udf.leftForwardSet, udf.leftDiscardSet, leftProvides, parentNeeds, writes, outputPositions)
populateSets(udf.rightForwardSet, udf.rightDiscardSet, rightProvides, parentNeeds, writes, outputPositions)
}
case JoinNode(udf, _, _, leftInput, rightInput) => {
val leftProvides = edgeDependencies(leftInput).childProvides
val rightProvides = edgeDependencies(rightInput).childProvides
val parentNeeds = edgeDependencies(node.getOutgoingConnections.head).childProvides
val writes = udf.outputFields.toIndexSet
populateSets(udf.leftForwardSet, udf.leftDiscardSet, leftProvides, parentNeeds, writes, outputPositions)
populateSets(udf.rightForwardSet, udf.rightDiscardSet, rightProvides, parentNeeds, writes, outputPositions)
}
case MapNode(udf, input) => {
val inputProvides = edgeDependencies(input).childProvides
val parentNeeds = edgeDependencies(node.getOutgoingConnections.head).childProvides
val writes = udf.outputFields.toIndexSet
populateSets(udf.forwardSet, udf.discardSet, inputProvides, parentNeeds, writes, outputPositions)
}
case ReduceNode(udf, _, input) => {
val inputProvides = edgeDependencies(input).childProvides
val parentNeeds = edgeDependencies(node.getOutgoingConnections.head).childProvides
val writes = udf.outputFields.toIndexSet
populateSets(udf.forwardSet, udf.discardSet, inputProvides, parentNeeds, writes, outputPositions)
}
}
node.getIncomingConnections.map(_.getSource).foldLeft(visited + node)(updateAmbientFields(outputPositions, edgeDependencies))
}
}
}
private def populateSets(forwards: mutable.Set[GlobalPos], discards: mutable.Set[GlobalPos], childProvides: Set[Int], parentNeeds: Set[Int], writes: Set[Int], outputPositions: Map[Int, GlobalPos]): Unit = {
forwards.clear()
forwards ++= (parentNeeds -- writes).intersect(childProvides).map(outputPositions(_))
discards.clear()
discards ++= (childProvides -- parentNeeds -- writes).map(outputPositions(_))
}
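// Worked example: childProvides = {1, 2, 3}, parentNeeds = {2, 4}, writes = {4} gives
// forwards = {pos(2)} (still needed downstream, not rewritten here) and
// discards = {pos(1), pos(3)} (provided by the child but no longer needed).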
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis.postPass
import scala.collection.JavaConversions._
import eu.stratosphere.scala.analysis._
import eu.stratosphere.scala.contracts._
import eu.stratosphere.pact.compiler.plan._
import eu.stratosphere.pact.compiler.plan.candidate.OptimizedPlan
object EdgeDependencySets {
import Extractors._
case class EdgeDependencySet(parentNeeds: Set[Int], childProvides: Set[Int] = Set())
def computeEdgeDependencySets(plan: OptimizedPlan, outputSets: Map[OptimizerNode, Set[Int]]): Map[PactConnection, EdgeDependencySet] = {
plan.getDataSinks.map(_.getSinkNode).foldLeft(Map[PactConnection, EdgeDependencySet]())(computeEdgeDependencySets(outputSets))
}
private def computeEdgeDependencySets(outputSets: Map[OptimizerNode, Set[Int]])(edgeDependencySets: Map[PactConnection, EdgeDependencySet], node: OptimizerNode): Map[PactConnection, EdgeDependencySet] = {
// top-down traversal: parentNeeds stays None as long as any parent has not yet been visited
val parentNeeds = node.getOutgoingConnections().foldLeft(Option(Set[Int]())) {
case (None, _) => None
case (Some(acc), parent) => edgeDependencySets.get(parent) map { acc ++ _.parentNeeds }
}
parentNeeds match {
case None => edgeDependencySets
case Some(parentNeeds) => computeEdgeDependencySets(node, parentNeeds, outputSets, edgeDependencySets)
}
}
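// Sketch: a node whose parents are A (needs {1, 2}) and a not-yet-visited B is skipped
// on the first encounter; the recursion reaches it again through B, and only then is the
// union of all parent needs propagated to the node's own input edges.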
private def computeEdgeDependencySets(node: OptimizerNode, parentNeeds: Set[Int], outputSets: Map[OptimizerNode, Set[Int]], edgeDependencySets: Map[PactConnection, EdgeDependencySet]): Map[PactConnection, EdgeDependencySet] = {
def updateEdges(needs: (PactConnection, Set[Int])*): Map[PactConnection, EdgeDependencySet] = {
val updParents = node.getOutgoingConnections().foldLeft(edgeDependencySets) { (edgeDependencySets, parent) =>
val entry = edgeDependencySets(parent)
edgeDependencySets.updated(parent, entry.copy(childProvides = parentNeeds))
}
needs.foldLeft(updParents) {
case (edgeDependencySets, (inConn, needs)) => {
val updInConn = edgeDependencySets.updated(inConn, EdgeDependencySet(needs))
computeEdgeDependencySets(outputSets)(updInConn, inConn.getSource)
}
}
}
for (udf <- node.getUDF) {
// suppress outputs that aren't needed by any parent
val writeFields = udf.outputFields filter { _.isUsed }
val unused = writeFields filterNot { f => parentNeeds.contains(f.globalPos.getValue) }
for (field <- unused) {
field.isUsed = false
if (field.globalPos.isIndex)
field.globalPos.setIndex(Int.MinValue)
}
}
node match {
case DataSinkNode(udf, input) => {
val needs = udf.inputFields.toIndexSet
updateEdges(input -> needs)
}
case DataSourceNode(udf) => {
updateEdges()
}
case CoGroupNode(udf, leftKey, rightKey, leftInput, rightInput) => {
val leftReads = udf.leftInputFields.toIndexSet ++ leftKey.selectedFields.toIndexSet
val rightReads = udf.rightInputFields.toIndexSet ++ rightKey.selectedFields.toIndexSet
val writes = udf.outputFields.toIndexSet
val parentPreNeeds = parentNeeds -- writes
val parentLeftNeeds = parentPreNeeds.intersect(outputSets(leftInput.getSource))
val parentRightNeeds = parentPreNeeds.intersect(outputSets(rightInput.getSource))
val leftForwards = udf.leftForwardSet.map(_.getValue)
val rightForwards = udf.rightForwardSet.map(_.getValue)
val (leftRes, rightRes) = parentLeftNeeds.intersect(parentRightNeeds).partition {
case index if leftForwards(index) && !rightForwards(index) => true
case index if !leftForwards(index) && rightForwards(index) => false
case _ => throw new UnsupportedOperationException("Schema conflict: cannot forward the same field from both sides of a two-input operator.")
}
val leftNeeds = (parentLeftNeeds -- rightRes) ++ leftReads
val rightNeeds = (parentRightNeeds -- leftRes) ++ rightReads
updateEdges(leftInput -> leftNeeds, rightInput -> rightNeeds)
}
case CrossNode(udf, leftInput, rightInput) => {
val leftReads = udf.leftInputFields.toIndexSet
val rightReads = udf.rightInputFields.toIndexSet
val writes = udf.outputFields.toIndexSet
val parentPreNeeds = parentNeeds -- writes
val parentLeftNeeds = parentPreNeeds.intersect(outputSets(leftInput.getSource))
val parentRightNeeds = parentPreNeeds.intersect(outputSets(rightInput.getSource))
val leftForwards = udf.leftForwardSet.map(_.getValue)
val rightForwards = udf.rightForwardSet.map(_.getValue)
val (leftRes, rightRes) = parentLeftNeeds.intersect(parentRightNeeds).partition {
case index if leftForwards(index) && !rightForwards(index) => true
case index if !leftForwards(index) && rightForwards(index) => false
case _ => throw new UnsupportedOperationException("Schema conflict: cannot forward the same field from both sides of a two-input operator.")
}
val leftNeeds = (parentLeftNeeds -- rightRes) ++ leftReads
val rightNeeds = (parentRightNeeds -- leftRes) ++ rightReads
updateEdges(leftInput -> leftNeeds, rightInput -> rightNeeds)
}
case JoinNode(udf, leftKey, rightKey, leftInput, rightInput) => {
val leftReads = udf.leftInputFields.toIndexSet ++ leftKey.selectedFields.toIndexSet
val rightReads = udf.rightInputFields.toIndexSet ++ rightKey.selectedFields.toIndexSet
val writes = udf.outputFields.toIndexSet
val parentPreNeeds = parentNeeds -- writes
val parentLeftNeeds = parentPreNeeds.intersect(outputSets(leftInput.getSource))
val parentRightNeeds = parentPreNeeds.intersect(outputSets(rightInput.getSource))
val leftForwards = udf.leftForwardSet.map(_.getValue)
val rightForwards = udf.rightForwardSet.map(_.getValue)
val (leftRes, rightRes) = parentLeftNeeds.intersect(parentRightNeeds).partition {
case index if leftForwards(index) && !rightForwards(index) => true
case index if !leftForwards(index) && rightForwards(index) => false
case _ => throw new UnsupportedOperationException("Schema conflict: cannot forward the same field from both sides of a two-input operator.")
}
val leftNeeds = (parentLeftNeeds -- rightRes) ++ leftReads
val rightNeeds = (parentRightNeeds -- leftRes) ++ rightReads
updateEdges(leftInput -> leftNeeds, rightInput -> rightNeeds)
}
case MapNode(udf, input) => {
val reads = udf.inputFields.toIndexSet
val writes = udf.outputFields.toIndexSet
val needs = parentNeeds -- writes ++ reads
updateEdges(input -> needs)
}
case ReduceNode(udf, key, input) => {
val reads = udf.inputFields.toIndexSet ++ key.selectedFields.toIndexSet
val writes = udf.outputFields.toIndexSet
val needs = parentNeeds -- writes ++ reads
updateEdges(input -> needs)
}
case _: SinkJoiner | _: BinaryUnionNode => {
updateEdges(node.getIncomingConnections.map(_ -> parentNeeds): _*)
}
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis.postPass
import eu.stratosphere.pact.common.contract.CoGroupContract
import eu.stratosphere.pact.common.contract.CrossContract
import eu.stratosphere.pact.common.contract.GenericDataSink
import eu.stratosphere.pact.common.contract.GenericDataSource
import eu.stratosphere.pact.common.contract.MapContract
import eu.stratosphere.pact.common.contract.MatchContract
import eu.stratosphere.pact.common.contract.ReduceContract
import eu.stratosphere.pact.compiler.plan.BinaryUnionNode
import eu.stratosphere.pact.compiler.plan.BulkIterationNode
import eu.stratosphere.pact.compiler.plan.CoGroupNode
import eu.stratosphere.pact.compiler.plan.CrossNode
import eu.stratosphere.pact.compiler.plan.DataSinkNode
import eu.stratosphere.pact.compiler.plan.DataSourceNode
import eu.stratosphere.pact.compiler.plan.MapNode
import eu.stratosphere.pact.compiler.plan.MatchNode
import eu.stratosphere.pact.compiler.plan.OptimizerNode
import eu.stratosphere.pact.compiler.plan.PactConnection
import eu.stratosphere.pact.compiler.plan.ReduceNode
import eu.stratosphere.pact.compiler.plan.SinkJoiner
import eu.stratosphere.pact.compiler.plan.WorksetIterationNode
import eu.stratosphere.pact.generic.contract.WorksetIteration
import eu.stratosphere.scala.ScalaContract
import eu.stratosphere.scala.OneInputScalaContract
import eu.stratosphere.scala.OneInputKeyedScalaContract
import eu.stratosphere.scala.TwoInputScalaContract
import eu.stratosphere.scala.TwoInputKeyedScalaContract
import eu.stratosphere.scala.WorksetIterationScalaContract
import eu.stratosphere.scala.analysis.FieldSelector
import eu.stratosphere.scala.analysis.UDF
import eu.stratosphere.scala.analysis.UDF0
import eu.stratosphere.scala.analysis.UDF1
import eu.stratosphere.scala.analysis.UDF2
import eu.stratosphere.scala.BulkIterationScalaContract
import eu.stratosphere.pact.generic.contract.BulkIteration
import eu.stratosphere.scala.UnionScalaContract
object Extractors {
implicit def nodeToGetUDF(node: OptimizerNode) = new {
def getUDF: Option[UDF[_]] = node match {
case _: SinkJoiner | _: BinaryUnionNode => None
case _ => {
Some(node.getPactContract.asInstanceOf[ScalaContract[_]].getUDF)
}
}
}
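// Usage sketch: via this implicit view, `node.getUDF` yields None for synthetic nodes
// (SinkJoiner, BinaryUnionNode), so passes can safely write: for (udf <- node.getUDF) ...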
object DataSinkNode {
def unapply(node: OptimizerNode): Option[(UDF1[_, _], PactConnection)] = node match {
case node: DataSinkNode => node.getPactContract match {
case contract: GenericDataSink with OneInputScalaContract[_, _] => {
Some((contract.getUDF, node.getInputConnection))
}
case _ => None
}
case _ => None
}
}
object DataSourceNode {
def unapply(node: OptimizerNode): Option[(UDF0[_])] = node match {
case node: DataSourceNode => node.getPactContract() match {
case contract: GenericDataSource[_] with ScalaContract[_] => Some(contract.getUDF.asInstanceOf[UDF0[_]])
case _ => None
}
case _ => None
}
}
object CoGroupNode {
def unapply(node: OptimizerNode): Option[(UDF2[_, _, _], FieldSelector, FieldSelector, PactConnection, PactConnection)] = node match {
case node: CoGroupNode => node.getPactContract() match {
case contract: CoGroupContract with TwoInputKeyedScalaContract[_, _, _] => Some((contract.getUDF, contract.leftKey, contract.rightKey, node.getFirstIncomingConnection, node.getSecondIncomingConnection))
case _ => None
}
case _ => None
}
}
object CrossNode {
def unapply(node: OptimizerNode): Option[(UDF2[_, _, _], PactConnection, PactConnection)] = node match {
case node: CrossNode => node.getPactContract match {
case contract: CrossContract with TwoInputScalaContract[_, _, _] => Some((contract.getUDF, node.getFirstIncomingConnection, node.getSecondIncomingConnection))
case _ => None
}
case _ => None
}
}
object JoinNode {
def unapply(node: OptimizerNode): Option[(UDF2[_, _, _], FieldSelector, FieldSelector, PactConnection, PactConnection)] = node match {
case node: MatchNode => node.getPactContract match {
case contract: MatchContract with TwoInputKeyedScalaContract[_, _, _] => Some((contract.getUDF, contract.leftKey, contract.rightKey, node.getFirstIncomingConnection, node.getSecondIncomingConnection))
case _ => None
}
case _ => None
}
}
object MapNode {
def unapply(node: OptimizerNode): Option[(UDF1[_, _], PactConnection)] = node match {
case node: MapNode => node.getPactContract match {
case contract: MapContract with OneInputScalaContract[_, _] => Some((contract.getUDF, node.getIncomingConnection))
case _ => None
}
case _ => None
}
}
object UnionNode {
def unapply(node: OptimizerNode): Option[(UDF1[_, _], PactConnection)] = node match {
case node: MapNode => node.getPactContract match {
case contract: MapContract with UnionScalaContract[_] => Some((contract.getUDF, node.getIncomingConnection))
case _ => None
}
case _ => None
}
}
object ReduceNode {
def unapply(node: OptimizerNode): Option[(UDF1[_, _], FieldSelector, PactConnection)] = node match {
case node: ReduceNode => node.getPactContract match {
case contract: ReduceContract with OneInputKeyedScalaContract[_, _] => Some((contract.getUDF, contract.key, node.getIncomingConnection))
case _ => None
}
case _ => None
}
}
object WorksetIterationNode {
def unapply(node: OptimizerNode): Option[(UDF0[_], FieldSelector, PactConnection, PactConnection)] = node match {
case node: WorksetIterationNode => node.getPactContract match {
case contract: WorksetIteration with WorksetIterationScalaContract[_] => Some((contract.getUDF, contract.key, node.getFirstIncomingConnection(), node.getSecondIncomingConnection()))
case _ => None
}
case _ => None
}
}
object BulkIterationNode {
def unapply(node: OptimizerNode): Option[(UDF0[_], PactConnection)] = node match {
case node: BulkIterationNode => node.getPactContract match {
case contract: BulkIteration with BulkIterationScalaContract[_] => Some((contract.getUDF, node.getIncomingConnection()))
case _ => None
}
case _ => None
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis.postPass
import scala.collection.mutable
import scala.collection.JavaConversions._
import eu.stratosphere.scala.analysis._
import eu.stratosphere.scala.contracts._
import eu.stratosphere.pact.compiler.plan._
import eu.stratosphere.pact.compiler.plan.candidate.OptimizedPlan
object GlobalSchemaCompactor {
import Extractors._
def compactSchema(plan: OptimizedPlan): Unit = {
val (_, conflicts) = plan.getDataSinks.map(_.getSinkNode).foldLeft((Set[OptimizerNode](), Map[GlobalPos, Set[GlobalPos]]())) {
case ((visited, conflicts), node) => findConflicts(node, visited, conflicts)
}
// Reset all position indexes before reassigning them
conflicts.keys.foreach { _.setIndex(Int.MinValue) }
plan.getDataSinks.map(_.getSinkNode).foldLeft(Set[OptimizerNode]())(compactSchema(conflicts))
}
/**
* Two fields are in conflict when they exist in the same place (record) at the same time (plan node).
* If two fields are in conflict, then they must be assigned different indexes.
*
* p1 conflictsWith p2 =
* Exists(n in Nodes):
* p1 != p2 &&
* (
* (p1 in n.forwards && p2 in n.forwards) ||
* (p1 in n.forwards && p2 in n.outputs) ||
* (p2 in n.forwards && p1 in n.outputs)
* )
*/
private def findConflicts(node: OptimizerNode, visited: Set[OptimizerNode], conflicts: Map[GlobalPos, Set[GlobalPos]]): (Set[OptimizerNode], Map[GlobalPos, Set[GlobalPos]]) = {
visited.contains(node) match {
case true => (visited, conflicts)
case false => {
val (forwardPos, outputFields) = node.getUDF match {
case None => (Set[GlobalPos](), Set[OutputField]())
case Some(udf: UDF0[_]) => (Set[GlobalPos](), udf.outputFields.toSet)
case Some(udf: UDF1[_, _]) => (udf.forwardSet, udf.outputFields.toSet)
case Some(udf: UDF2[_, _, _]) => (udf.leftForwardSet ++ udf.rightForwardSet, udf.outputFields.toSet)
case _ => (Set[GlobalPos](), Set[OutputField]())
}
// resolve GlobalPos references to the instance that holds the actual index
val forwards = forwardPos map { _.resolve }
val outputs = outputFields filter { _.isUsed } map { _.globalPos.resolve }
val newConflictsF = forwards.foldLeft(conflicts) {
case (conflicts, fPos) => {
// add all other forwards and all outputs to this forward's conflict set
val fConflicts = conflicts.getOrElse(fPos, Set()) ++ (forwards filterNot { _ == fPos }) ++ outputs
conflicts.updated(fPos, fConflicts)
}
}
val newConflictsO = outputs.foldLeft(newConflictsF) {
case (conflicts, oPos) => {
// add all forwards to this output's conflict set
val oConflicts = conflicts.getOrElse(oPos, Set()) ++ forwards
conflicts.updated(oPos, oConflicts)
}
}
node.getIncomingConnections.map(_.getSource).foldLeft((visited + node, newConflictsO)) {
case ((visited, conflicts), node) => findConflicts(node, visited, conflicts)
}
}
}
}
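// Worked example (hypothetical positions): a node that forwards {p1, p2} and writes {p3}
// contributes p1 -> {p2, p3}, p2 -> {p1, p3} and p3 -> {p1, p2} to the conflict map, so
// the compactor can never assign any two of these fields the same index.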
/**
* Assign indexes bottom-up, giving lower values to fields with larger conflict sets.
* This ordering should do a decent job of minimizing the number of gaps between fields.
*/
private def compactSchema(conflicts: Map[GlobalPos, Set[GlobalPos]])(visited: Set[OptimizerNode], node: OptimizerNode): Set[OptimizerNode] = {
visited.contains(node) match {
case true => visited
case false => {
val newVisited = node.getIncomingConnections.map(_.getSource).foldLeft(visited + node)(compactSchema(conflicts))
val outputFields = node.getUDF match {
case None => Seq[OutputField]()
case Some(udf) => udf.outputFields filter { _.isUsed }
}
val outputs = outputFields map {
case field => {
val pos = field.globalPos.resolve
(pos, field.localPos, conflicts(pos) map { _.getValue })
}
} sortBy {
case (_, localPos, posConflicts) => (Int.MaxValue - posConflicts.size, localPos)
}
val initUsed = outputs map { _._1.getValue } filter { _ >= 0 } toSet
val used = outputs.filter(_._1.getValue < 0).foldLeft(initUsed) {
case (used, (pos, _, conflicts)) => {
val index = chooseIndexValue(used ++ conflicts)
pos.setIndex(index)
used + index
}
}
node.getUDF match {
case Some(udf: UDF1[_, _]) => updateDiscards(used, udf.discardSet)
case Some(udf: UDF2[_, _, _]) => updateDiscards(used, udf.leftDiscardSet, udf.rightDiscardSet)
case _ =>
}
newVisited
}
}
}
private def chooseIndexValue(used: Set[Int]): Int = {
var index = 0
while (used(index)) {
index = index + 1
}
index
}
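// E.g. chooseIndexValue(Set(0, 1, 3)) == 2: the smallest index not already in use.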
private def updateDiscards(outputs: Set[Int], discardSets: mutable.Set[GlobalPos]*): Unit = {
for (discardSet <- discardSets) {
val overwrites = discardSet filter { pos => outputs.contains(pos.getValue) } toList
for (pos <- overwrites)
discardSet.remove(pos)
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis.postPass
import scala.collection.JavaConversions.asScalaBuffer
import scala.collection.JavaConversions.collectionAsScalaIterable
import eu.stratosphere.pact.compiler.plan.OptimizerNode
import eu.stratosphere.pact.compiler.plan.candidate.OptimizedPlan
import eu.stratosphere.scala.ScalaContract
trait GlobalSchemaOptimizer {
import Extractors._
def optimizeSchema(plan: OptimizedPlan, compactSchema: Boolean): Unit = {
// val (outputSets, outputPositions) = OutputSets.computeOutputSets(plan)
// val edgeSchemas = EdgeDependencySets.computeEdgeDependencySets(plan, outputSets)
//
// AmbientFieldDetector.updateAmbientFields(plan, edgeSchemas, outputPositions)
//
// if (compactSchema) {
// GlobalSchemaCompactor.compactSchema(plan)
// }
GlobalSchemaPrinter.printSchema(plan)
plan.getDataSinks.map(_.getSinkNode).foldLeft(Set[OptimizerNode]())(persistConfiguration)
}
private def persistConfiguration(visited: Set[OptimizerNode], node: OptimizerNode): Set[OptimizerNode] = {
visited.contains(node) match {
case true => visited
case false => {
val children = node.getIncomingConnections.map(_.getSource).toSet
val newVisited = children.foldLeft(visited + node)(persistConfiguration)
node.getPactContract match {
case c: ScalaContract[_] => c.persistConfiguration(Some(node))
case _ =>
}
newVisited
}
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis.postPass
import scala.Array.canBuildFrom
import scala.collection.JavaConversions.asScalaBuffer
import scala.collection.JavaConversions.collectionAsScalaIterable
import Extractors.CoGroupNode
import Extractors.CrossNode
import Extractors.DataSinkNode
import Extractors.DataSourceNode
import Extractors.JoinNode
import Extractors.MapNode
import Extractors.ReduceNode
import eu.stratosphere.pact.compiler.plan.BinaryUnionNode
import eu.stratosphere.pact.compiler.plan.OptimizerNode
import eu.stratosphere.pact.compiler.plan.SinkJoiner
import eu.stratosphere.pact.compiler.plan.candidate.OptimizedPlan
import eu.stratosphere.scala.analysis.FieldSet
import eu.stratosphere.scala.analysis.FieldSelector
import eu.stratosphere.pact.common.plan.Plan
import eu.stratosphere.pact.generic.contract.Contract
import eu.stratosphere.pact.generic.contract.SingleInputContract
import eu.stratosphere.pact.generic.contract.DualInputContract
import eu.stratosphere.pact.common.contract.GenericDataSink
object GlobalSchemaPrinter {
import Extractors._
def printSchema(plan: OptimizedPlan): Unit = {
println("### " + plan.getJobName + " ###")
plan.getDataSinks.map(_.getSinkNode).foldLeft(Set[OptimizerNode]())(printSchema)
println("####" + ("#" * plan.getJobName.length) + "####")
println()
}
private def printSchema(visited: Set[OptimizerNode], node: OptimizerNode): Set[OptimizerNode] = {
visited.contains(node) match {
case true => visited
case false => {
val children = node.getIncomingConnections.map(_.getSource).toSet
val newVisited = children.foldLeft(visited + node)(printSchema)
node match {
case _: SinkJoiner | _: BinaryUnionNode =>
case DataSinkNode(udf, input) => {
printInfo(node, "Sink",
Seq(),
Seq(("", udf.inputFields)),
Seq(("", udf.getForwardIndexArray)),
Seq(("", udf.getDiscardIndexArray)),
udf.outputFields
)
}
case DataSourceNode(udf) => {
printInfo(node, "Source",
Seq(),
Seq(),
Seq(),
Seq(),
udf.outputFields
)
}
case CoGroupNode(udf, leftKey, rightKey, leftInput, rightInput) => {
printInfo(node, "CoGroup",
Seq(("L", leftKey), ("R", rightKey)),
Seq(("L", udf.leftInputFields), ("R", udf.rightInputFields)),
Seq(("L", udf.getLeftForwardIndexArray), ("R", udf.getRightForwardIndexArray)),
Seq(("L", udf.getLeftDiscardIndexArray), ("R", udf.getRightDiscardIndexArray)),
udf.outputFields
)
}
case CrossNode(udf, leftInput, rightInput) => {
printInfo(node, "Cross",
Seq(),
Seq(("L", udf.leftInputFields), ("R", udf.rightInputFields)),
Seq(("L", udf.getLeftForwardIndexArray), ("R", udf.getRightForwardIndexArray)),
Seq(("L", udf.getLeftDiscardIndexArray), ("R", udf.getRightDiscardIndexArray)),
udf.outputFields
)
}
case JoinNode(udf, leftKey, rightKey, leftInput, rightInput) => {
printInfo(node, "Join",
Seq(("L", leftKey), ("R", rightKey)),
Seq(("L", udf.leftInputFields), ("R", udf.rightInputFields)),
Seq(("L", udf.getLeftForwardIndexArray), ("R", udf.getRightForwardIndexArray)),
Seq(("L", udf.getLeftDiscardIndexArray), ("R", udf.getRightDiscardIndexArray)),
udf.outputFields
)
}
case MapNode(udf, input) => {
printInfo(node, "Map",
Seq(),
Seq(("", udf.inputFields)),
Seq(("", udf.getForwardIndexArray)),
Seq(("", udf.getDiscardIndexArray)),
udf.outputFields
)
}
case UnionNode(udf, input) => {
printInfo(node, "Union",
Seq(),
Seq(("", udf.inputFields)),
Seq(("", udf.getForwardIndexArray)),
Seq(("", udf.getDiscardIndexArray)),
udf.outputFields
)
}
case ReduceNode(udf, key, input) => {
// val contract = node.asInstanceOf[Reduce4sContract[_, _, _]]
// contract.userCombineCode map { _ =>
// printInfo(node, "Combine",
// Seq(("", key)),
// Seq(("", udf.inputFields)),
// Seq(("", contract.combineForwardSet.toArray)),
// Seq(("", contract.combineDiscardSet.toArray)),
// udf.inputFields
// )
// }
printInfo(node, "Reduce",
Seq(("", key)),
Seq(("", udf.inputFields)),
Seq(("", udf.getForwardIndexArray)),
Seq(("", udf.getDiscardIndexArray)),
udf.outputFields
)
}
case WorksetIterationNode(udf, key, input1, input2) => {
printInfo(node, "WorksetIterate",
Seq(("", key)),
Seq(),
Seq(),
Seq(),
udf.outputFields)
}
case BulkIterationNode(udf, input1) => {
printInfo(node, "BulkIterate",
Seq(),
Seq(),
Seq(),
Seq(),
udf.outputFields)
}
}
newVisited
}
}
}
private def printInfo(node: OptimizerNode, kind: String, keys: Seq[(String, FieldSelector)], reads: Seq[(String, FieldSet[_])], forwards: Seq[(String, Array[Int])], discards: Seq[(String, Array[Int])], writes: FieldSet[_]): Unit = {
def indexesToStrings(pre: String, indexes: Array[Int]) = indexes map {
case -1 => "_"
case i => pre + i
}
val formatString = "%s (%s): K{%s}: R[%s] => F[%s] - D[%s] + W[%s]"
val name = node.getName
val sKeys = keys flatMap { case (pre, value) => value.selectedFields.toSerializerIndexArray.map(pre + _) } mkString ", "
val sReads = reads flatMap { case (pre, value) => indexesToStrings(pre, value.toSerializerIndexArray) } mkString ", "
val sForwards = forwards flatMap { case (pre, value) => value.sorted.map(pre + _) } mkString ", "
val sDiscards = discards flatMap { case (pre, value) => value.sorted.map(pre + _) } mkString ", "
val sWrites = indexesToStrings("", writes.toSerializerIndexArray) mkString ", "
println(formatString.format(name, kind, sKeys, sReads, sForwards, sDiscards, sWrites))
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.analysis.postPass
import scala.collection.JavaConversions._
import eu.stratosphere.scala.analysis._
import eu.stratosphere.scala.contracts._
import eu.stratosphere.pact.compiler.plan._
import eu.stratosphere.pact.compiler.plan.candidate.OptimizedPlan
object OutputSets {
import Extractors._
def computeOutputSets(plan: OptimizedPlan): (Map[OptimizerNode, Set[Int]], Map[Int, GlobalPos]) = {
val root = plan.getDataSinks.map(s => s.getSinkNode: OptimizerNode).reduceLeft((n1, n2) => new SinkJoiner(n1, n2))
val outputSets = computeOutputSets(Map[OptimizerNode, Set[GlobalPos]](), root)
val outputPositions = outputSets(root).map(pos => (pos.getValue, pos)).toMap
(outputSets.mapValues(_.map(_.getValue)), outputPositions)
}
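// Sketch (hypothetical plan Source -> Map -> Sink): outputSets(mapNode) holds the source's
// used output positions plus the map's own writes, i.e. everything a record carries there.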
private def computeOutputSets(outputSets: Map[OptimizerNode, Set[GlobalPos]], node: OptimizerNode): Map[OptimizerNode, Set[GlobalPos]] = {
outputSets.contains(node) match {
case true => outputSets
case false => {
val children = node.getIncomingConnections.map(_.getSource).toSet
val newOutputSets = children.foldLeft(outputSets)(computeOutputSets)
val childOutputs = children.map(newOutputSets(_)).flatten
val nodeOutputs = node.getUDF map { _.outputFields.filter(_.isUsed).map(_.globalPos).toSet } getOrElse Set()
newOutputSets.updated(node, childOutputs ++ nodeOutputs)
}
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
class Counter {
private var value: Int = 0
def next: Int = {
this.synchronized {
val current = value
value += 1
current
}
}
}
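// Usage sketch: val ids = new Counter; ids.next yields 0, then 1, then 2, ...
// `next` is synchronized, so the generated ids stay unique even under concurrent access.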
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
import scala.reflect.macros.Context
trait DeserializeMethodGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with TreeGen[C] with SerializerGen[C] with Loggers[C] =>
import c.universe._
protected def mkDeserialize(desc: UDTDescriptor, listImpls: Map[Int, Type]): List[Tree] = {
// val rootRecyclingOn = mkMethod("deserializeRecyclingOn", Flag.OVERRIDE | Flag.FINAL, List(("record", typeOf[eu.stratosphere.pact.common.`type`.PactRecord])), desc.tpe, {
val rootRecyclingOn = mkMethod("deserializeRecyclingOn", Flag.FINAL, List(("record", typeOf[eu.stratosphere.pact.common.`type`.PactRecord])), desc.tpe, {
val env = GenEnvironment(listImpls, "flat" + desc.id, false, true, true, true)
mkSingle(genDeserialize(desc, Ident("record"), env, Map()))
})
// val rootRecyclingOff = mkMethod("deserializeRecyclingOff", Flag.OVERRIDE | Flag.FINAL, List(("record", typeOf[eu.stratosphere.pact.common.`type`.PactRecord])), desc.tpe, {
val rootRecyclingOff = mkMethod("deserializeRecyclingOff", Flag.FINAL, List(("record", typeOf[eu.stratosphere.pact.common.`type`.PactRecord])), desc.tpe, {
val env = GenEnvironment(listImpls, "flat" + desc.id, false, false, true, true)
mkSingle(genDeserialize(desc, Ident("record"), env, Map()))
})
val aux = desc.getRecursiveRefs map { desc =>
mkMethod("deserialize" + desc.id, Flag.PRIVATE | Flag.FINAL, List(("record", typeOf[eu.stratosphere.pact.common.`type`.PactRecord])), desc.tpe, {
val env = GenEnvironment(listImpls, "boxed" + desc.id, true, false, false, true)
mkSingle(genDeserialize(desc, Ident("record"), env, Map()))
})
}
rootRecyclingOn +: rootRecyclingOff +: aux.toList
}
private def genDeserialize(desc: UDTDescriptor, source: Tree, env: GenEnvironment, scope: Map[Int, (String, Type)]): Seq[Tree] = desc match {
case PrimitiveDescriptor(id, _, default, _) => {
val chk = env.mkChkIdx(id)
val des = env.mkGetFieldInto(id, source)
val get = env.mkGetValue(id)
Seq(mkIf(chk, Block(List(des), get), default))
}
case BoxedPrimitiveDescriptor(id, tpe, _, _, box, _) => {
val des = env.mkGetFieldInto(id, source)
val chk = mkAnd(env.mkChkIdx(id), des)
val get = box(env.mkGetValue(id))
Seq(mkIf(chk, get, mkNull))
}
case list @ ListDescriptor(id, tpe, _, elem) => {
val chk = mkAnd(env.mkChkIdx(id), env.mkNotIsNull(id, source))
val (init, pactList) = env.reentrant match {
// This is a bit conservative, but avoids runtime checks
// and/or even more specialized deserialize() methods to
// track whether it's safe to reuse the list variable.
case true => {
val listTpe = env.listImpls(id)
val list = mkVal("list" + id, NoFlags, false, listTpe, New(TypeTree(listTpe), List(List())))
(list, Ident("list" + id: TermName))
}
case false => {
val clear = Apply(Select(env.mkSelectWrapper(id), "clear"), List())
(clear, env.mkSelectWrapper(id))
}
}
// val buildTpe = appliedType(builderClass.tpe, List(elem.tpe, tpe))
// val build = mkVal(env.methodSym, "b" + id, 0, false, buildTpe) { _ => Apply(Select(cbf(), "apply"), List()) }
// val userList = mkVal("b" + id, NoFlags, false, tpe, New(TypeTree(tpe), List(List())))
val buildTpe = mkBuilderOf(elem.tpe, tpe)
val cbf = c.inferImplicitValue(mkCanBuildFromOf(tpe, elem.tpe, tpe))
val build = mkVal("b" + id, NoFlags, false, buildTpe, Apply(Select(cbf, "apply": TermName), List()))
val des = env.mkGetFieldInto(id, source, pactList)
val body = genDeserializeList(elem, pactList, Ident("b" + id: TermName), env.copy(allowRecycling = false, chkNull = true), scope)
val stats = init +: des +: build +: body
Seq(mkIf(chk, Block(stats.init.toList, stats.last), mkNull))
}
// we have a mutable UDT and the context allows recycling
case CaseClassDescriptor(_, tpe, true, _, getters) if env.allowRecycling => {
val fields = getters filterNot { _.isBaseField } map {
case FieldAccessor(_, _, _, _, desc) => (desc.id, mkVal("v" + desc.id, NoFlags, false, desc.tpe, {
mkSingle(genDeserialize(desc, source, env, scope))
}), desc.tpe, "v" + desc.id)
}
val newScope = scope ++ (fields map { case (id, tree, tpe, name) => id -> (name, tpe) })
val stats = fields map { _._2 }
val setterStats = getters map {
case FieldAccessor(_, setter, fTpe, _, fDesc) => {
val (name, tpe) = newScope(fDesc.id)
val castVal = maybeMkAsInstanceOf(Ident(name: TermName))(c.WeakTypeTag(tpe), c.WeakTypeTag(fTpe))
env.mkCallSetMutableField(desc.id, setter, castVal)
}
}
val ret = env.mkSelectMutableUdtInst(desc.id)
(stats ++ setterStats) :+ ret
}
case CaseClassDescriptor(_, tpe, _, _, getters) => {
val fields = getters filterNot { _.isBaseField } map {
case FieldAccessor(_, _, _, _, desc) => (desc.id, mkVal("v" + desc.id, NoFlags, false, desc.tpe, {
mkSingle(genDeserialize(desc, source, env, scope))
}), desc.tpe, "v" + desc.id)
}
val newScope = scope ++ (fields map { case (id, tree, tpe, name) => id -> (name, tpe) })
val stats = fields map { _._2 }
val args = getters map {
case FieldAccessor(_, _, fTpe, _, fDesc) => {
val (name, tpe) = newScope(fDesc.id)
maybeMkAsInstanceOf(Ident(name: TermName))(c.WeakTypeTag(tpe), c.WeakTypeTag(fTpe))
}
}
val ret = New(TypeTree(tpe), List(args.toList))
stats :+ ret
}
case BaseClassDescriptor(_, tpe, Seq(tagField, baseFields @ _*), subTypes) => {
val fields = baseFields map {
case FieldAccessor(_, _, _, _, desc) => (desc.id, mkVal("v" + desc.id, NoFlags, false, desc.tpe, {
val special = desc match {
case d @ PrimitiveDescriptor(id, _, _, _) if id == tagField.desc.id => d.copy(default = Literal(Constant(-1)))
case _ => desc
}
mkSingle(genDeserialize(special, source, env, scope))
}), desc.tpe, "v" + desc.id)
}
val newScope = scope ++ (fields map { case (id, tree, tpe, name) => id -> (name, tpe) })
val stats = fields map { _._2 }
val cases = subTypes.zipWithIndex.toList map {
case (dSubType, i) => {
val code = mkSingle(genDeserialize(dSubType, source, env, newScope))
val pat = Bind("tag": TermName, Literal(Constant(i)))
CaseDef(pat, EmptyTree, code)
}
}
val chk = env.mkChkIdx(tagField.desc.id)
val des = env.mkGetFieldInto(tagField.desc.id, source)
val get = env.mkGetValue(tagField.desc.id)
Seq(mkIf(chk, Block(stats.toList :+ des, Match(get, cases)), mkNull))
}
case RecursiveDescriptor(id, tpe, refId) => {
val chk = mkAnd(env.mkChkIdx(id), env.mkNotIsNull(id, source))
val rec = mkVal("record" + id, NoFlags, false, typeOf[eu.stratosphere.pact.common.`type`.PactRecord], New(TypeTree(typeOf[eu.stratosphere.pact.common.`type`.PactRecord]), List(List())))
val get = env.mkGetFieldInto(id, source, Ident("record" + id: TermName))
val des = env.mkCallDeserialize(refId, Ident("record" + id: TermName))
Seq(mkIf(chk, Block(List(rec, get), des), mkNull))
}
case _ => Seq(mkNull)
}
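// For orientation, a hypothetical PrimitiveDescriptor with field id 0 and
// root prefix "flat0" expands to code morally equivalent to:
//   if (flat0Idx0 >= 0) { record.getFieldInto(flat0Idx0, w0); w0.getValue() }
//   else default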
private def genDeserializeList(elem: UDTDescriptor, source: Tree, target: Tree, env: GenEnvironment, scope: Map[Int, (String, Type)]): Seq[Tree] = {
val size = mkVal("size", NoFlags, false, definitions.IntTpe, Apply(Select(source, "size"), List()))
val sizeHint = Apply(Select(target, "sizeHint"), List(Ident("size": TermName)))
val i = mkVar("i", NoFlags, false, definitions.IntTpe, mkZero)
val loop = mkWhile(Apply(Select(Ident("i": TermName), "$less"), List(Ident("size": TermName)))) {
val item = mkVal("item", NoFlags, false, getListElemWrapperType(elem, env), Apply(Select(source, "get"), List(Ident("i": TermName))))
val (stats, value) = elem match {
case PrimitiveDescriptor(_, _, _, wrapper) => (Seq(), env.mkGetValue(Ident("item": TermName)))
case BoxedPrimitiveDescriptor(_, _, _, wrapper, box, _) => (Seq(), box(env.mkGetValue(Ident("item": TermName))))
case ListDescriptor(id, tpe, _, innerElem) => {
// val buildTpe = appliedType(builderClass.tpe, List(innerElem.tpe, tpe))
// val build = mkVal(env.methodSym, "b" + id, 0, false, buildTpe) { _ => Apply(Select(cbf(), "apply"), List()) }
val buildTpe = mkBuilderOf(innerElem.tpe, tpe)
val cbf = c.inferImplicitValue(mkCanBuildFromOf(tpe, innerElem.tpe, tpe))
val build = mkVal("b" + id, NoFlags, false, buildTpe, Apply(Select(cbf, "apply": TermName), List()))
val body = mkVal("v" + id, NoFlags, false, elem.tpe,
mkSingle(genDeserializeList(innerElem, Ident("item": TermName), Ident("b" + id: TermName), env, scope)))
(Seq(build, body), Ident("v" + id: TermName))
}
case RecursiveDescriptor(id, tpe, refId) => (Seq(), env.mkCallDeserialize(refId, Ident("item": TermName)))
case _ => {
val body = genDeserialize(elem, Ident("item": TermName), env.copy(idxPrefix = "boxed" + elem.id, chkIndex = false, chkNull = false), scope)
val v = mkVal("v" + elem.id, NoFlags, false, elem.tpe, mkSingle(body))
(Seq(v), Ident("v" + elem.id: TermName))
}
}
val chk = env.mkChkNotNull(Ident("item": TermName), elem.tpe)
val add = Apply(Select(target, "$plus$eq"), List(value))
val addNull = Apply(Select(target, "$plus$eq"), List(mkNull))
val inc = Assign(Ident("i": TermName), Apply(Select(Ident("i": TermName), "$plus"), List(mkOne)))
Block(List(item, mkIf(chk, mkSingle(stats :+ add), addNull)), inc)
}
val get = Apply(Select(target, "result"), List())
Seq(size, sizeHint, i, loop, get)
}
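// Schematically, the generated list deserializer is (with `source` a read
// PactList wrapper and `target` a collection builder):
//   val size = source.size(); target.sizeHint(size)
//   var i = 0
//   while (i < size) {
//     val item = source.get(i)
//     if (item != null) target += deserialized(item) else target += null
//     i += 1
//   }
//   target.result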
private def getListElemWrapperType(desc: UDTDescriptor, env: GenEnvironment): Type = desc match {
case PrimitiveDescriptor(_, _, _, wrapper) => wrapper
case BoxedPrimitiveDescriptor(_, _, _, wrapper, _, _) => wrapper
case ListDescriptor(id, _, _, _) => env.listImpls(id)
case _ => typeOf[eu.stratosphere.pact.common.`type`.PactRecord]
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
import scala.reflect.macros.Context
trait Loggers[C <: Context] { this: MacroContextHolder[C] =>
import c.universe._
abstract sealed class LogLevel extends Ordered[LogLevel] {
protected[Loggers] val toInt: Int
override def compare(that: LogLevel) = this.toInt.compare(that.toInt)
}
object LogLevel {
def unapply(name: String): Option[LogLevel] = name match {
case "error" | "Error" => Some(Error)
case "warn" | "Warn" => Some(Warn)
case "debug" | "Debug" => Some(Debug)
case "inspect" | "Inspect" => Some(Inspect)
case _ => None
}
case object Error extends LogLevel { override val toInt = 1 }
case object Warn extends LogLevel { override val toInt = 2 }
case object Debug extends LogLevel { override val toInt = 3 }
case object Inspect extends LogLevel { override val toInt = 4 }
}
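// Example: the extractor makes it easy to parse a level name coming from a
// compiler option, accepting both capitalizations:
//   "debug" match { case LogLevel(lvl) => logger.level = lvl; case _ => }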
object logger { var level: LogLevel = LogLevel.Warn }
private val counter = new Counter
trait Logger {
abstract sealed class Severity {
protected val toInt: Int
protected def reportInner(msg: String, pos: Position)
protected def formatMsg(msg: String) = msg
def isEnabled = this.toInt <= logger.level.toInt
def report(msg: String) = {
if (isEnabled) {
reportInner(formatMsg(msg), c.enclosingPosition)
}
}
}
case object Error extends Severity {
override val toInt = LogLevel.Error.toInt
override def reportInner(msg: String, pos: Position) = c.error(pos, msg)
}
case object Warn extends Severity {
override val toInt = LogLevel.Warn.toInt
override def reportInner(msg: String, pos: Position) = c.warning(pos, msg)
}
case object Debug extends Severity {
override val toInt = LogLevel.Debug.toInt
override def reportInner(msg: String, pos: Position) = c.info(pos, msg, true)
}
def getMsgAndStackLine(e: Throwable) = {
val lines = e.getStackTrace.map(_.toString)
val relevant = lines filter { _.contains("eu.stratosphere") }
val stackLine = relevant.headOption orElse lines.headOption getOrElse "<unknown>"
e.getMessage() + " @ " + stackLine
}
def posString(pos: Position): String = pos match {
case NoPosition => "?:?"
case _ => pos.line + ":" + pos.column
}
def safely(default: => Tree, inspect: Boolean)(onError: Throwable => String)(block: => Tree): Tree = {
try {
block
} catch {
case e:Throwable => {
Error.report(onError(e));
val ret = default
ret
}
}
}
def verbosely[T](obs: T => String)(block: => T): T = {
val ret = block
Debug.report(obs(ret))
ret
}
def maybeVerbosely[T](guard: T => Boolean)(obs: T => String)(block: => T): T = {
val ret = block
if (guard(ret)) Debug.report(obs(ret))
ret
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
import scala.reflect.macros.Context
class MacroContextHolder[C <: Context](val c: C)
object MacroContextHolder {
def newMacroHelper[C <: Context](c: C) = new MacroContextHolder[c.type](c)
with Loggers[c.type]
with UDTDescriptors[c.type]
with UDTAnalyzer[c.type]
with TreeGen[c.type]
with SerializerGen[c.type]
with SerializeMethodGen[c.type]
with DeserializeMethodGen[c.type]
with UDTGen[c.type]
with SelectionExtractor[c.type]
}
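// Typical use inside a def-macro implementation (a sketch; `impl` and T are
// placeholders): instantiate the holder for the concrete context and import
// its members to get all generator traits wired together.
//   def impl[T: c.WeakTypeTag](c: Context): c.Expr[T] = {
//     val helper = MacroContextHolder.newMacroHelper(c)
//     import helper._
//     // ... use getUDTDescriptor, mkUdtSerializerClass, etc.
//   }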
package eu.stratosphere.scala.codegen
import scala.reflect.macros.Context
import scala.Option.option2Iterable
import eu.stratosphere.scala.analysis.FieldSelector
trait SelectionExtractor[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with UDTAnalyzer[C] with Loggers[C] with TreeGen[C] =>
import c.universe._
def getSelector[T: c.WeakTypeTag, R: c.WeakTypeTag](fun: c.Expr[T => R]): Expr[List[Int]] =
(new SelectionExtractorInstance with Logger).extract(fun)
class SelectionExtractorInstance { this: Logger =>
def extract[T: c.WeakTypeTag, R: c.WeakTypeTag](fun: c.Expr[T => R]): Expr[List[Int]] = {
val result = getSelector(fun.tree) match {
case Left(errs) => Left(errs.toList)
case Right(sels) => getUDTDescriptor(weakTypeOf[T]) match {
case UnsupportedDescriptor(id, tpe, errs) => Left(errs.toList)
case desc: UDTDescriptor => chkSelectors(desc, sels) match {
case Nil => Right(desc, sels map { _.tail })
case errs => Left(errs)
}
}
}
result match {
case Left(errs) => {
errs foreach { err => c.error(c.enclosingPosition, s"Error analyzing FieldSelector ${show(fun.tree)}: " + err) }
reify { throw new RuntimeException("Invalid key selector."); }
}
case Right((udtDesc, sels)) => {
val descs: List[Option[UDTDescriptor]] = sels flatMap { sel: List[String] => udtDesc.select(sel) }
val ids = descs map { _ map { _.id } }
ids forall { _.isDefined } match {
case false => {
c.error(c.enclosingPosition, s"Could not determine ids of key fields: ${ids}")
reify { throw new RuntimeException("Invalid key selector."); }
}
case true => {
val generatedIds = ids map { _.get } map { id => Literal(Constant(id: Int)) }
val generatedList = mkList(generatedIds)
reify {
val list = c.Expr[List[Int]](generatedList).splice
list
}
}
}
}
}
}
private def getSelector(tree: Tree): Either[List[String], List[List[String]]] = tree match {
case Function(List(p), body) => getSelector(body, Map(p.symbol -> Nil)) match {
case err @ Left(_) => err
case Right(sels) => Right(sels map { sel => p.name.toString +: sel })
}
case _ => Left(List("expected lambda expression literal but found " + show(tree)))
}
private def getSelector(tree: Tree, roots: Map[Symbol, List[String]]): Either[List[String], List[List[String]]] = tree match {
case SimpleMatch(body, bindings) => getSelector(body, roots ++ bindings)
case Match(_, List(CaseDef(pat, EmptyTree, _))) => Left(List("case pattern is too complex"))
case Match(_, List(CaseDef(_, guard, _))) => Left(List("case pattern is guarded"))
case Match(_, _ :: _ :: _) => Left(List("match contains more than one case"))
case TupleCtor(args) => {
val (errs, sels) = args.map(arg => getSelector(arg, roots)).partition(_.isLeft)
errs match {
case Nil => Right(sels.map(_.right.get).flatten)
case _ => Left(errs.map(_.left.get).flatten)
}
}
case Apply(tpt@TypeApply(_, _), _) => Left(List("constructor call on non-tuple type " + tpt.tpe))
case Ident(name) => roots.get(tree.symbol) match {
case Some(sel) => Right(List(sel))
case None => Left(List("unexpected identifier " + name))
}
case Select(src, member) => getSelector(src, roots) match {
case err @ Left(_) => err
case Right(List(sel)) => Right(List(sel :+ member.toString))
case _ => Left(List("unsupported selection"))
}
case _ => Left(List("unsupported construct of kind " + showRaw(tree)))
}
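// Examples of selector shapes this accepts (for a hypothetical
// `case class Pos(x: Int, y: Int)`):
//   p => p.x                       // Right(List(List("p", "x")))
//   p => (p.x, p.y)                // tuple of selections
//   (p: Pos) => p match { case Pos(a, b) => (a, b) }  // via SimpleMatch
// Guarded patterns, multi-case matches, and arbitrary computation fall
// through to the error cases above.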
private object SimpleMatch {
def unapply(tree: Tree): Option[(Tree, Map[Symbol, List[String]])] = tree match {
case Match(arg, List(cd @ CaseDef(CasePattern(bindings), EmptyTree, body))) => Some((body, bindings))
case _ => None
}
private object CasePattern {
def unapply(tree: Tree): Option[Map[Symbol, List[String]]] = tree match {
case Apply(MethodTypeTree(params), binds) => {
val exprs = params.zip(binds) map {
case (p, CasePattern(inners)) => Some(inners map { case (sym, path) => (sym, p.name.toString +: path) })
case _ => None
}
if (exprs.forall(_.isDefined)) {
Some(exprs.flatten.flatten.toMap)
}
else
None
}
case Ident(_) | Bind(_, Ident(_)) => Some(Map(tree.symbol -> Nil))
case Bind(_, CasePattern(inners)) => Some(inners + (tree.symbol -> Nil))
case _ => None
}
}
private object MethodTypeTree {
def unapply(tree: Tree): Option[List[Symbol]] = tree match {
case _: TypeTree => tree.tpe match {
case MethodType(params, _) => Some(params)
case _ => None
}
case _ => None
}
}
}
private object TupleCtor {
def unapply(tree: Tree): Option[List[Tree]] = tree match {
case Apply(tpt@TypeApply(_, _), args) if isTupleTpe(tpt.tpe) => Some(args)
case _ => None
}
private def isTupleTpe(tpe: Type): Boolean = definitions.TupleClass.contains(tpe.typeSymbol)
}
}
protected def chkSelectors(udt: UDTDescriptor, sels: List[List[String]]): List[String] = {
sels flatMap { sel => chkSelector(udt, sel.head, sel.tail) }
}
protected def chkSelector(udt: UDTDescriptor, path: String, sel: List[String]): Option[String] = (udt, sel) match {
case (_, Nil) if udt.isPrimitiveProduct => None
case (_, Nil) => Some(path + ": " + udt.tpe + " is not a primitive or product of primitives")
case (_, field :: rest) => udt.select(field) match {
case None => Some("member " + field + " is not a case accessor of " + path + ": " + udt.tpe)
case Some(udt) => chkSelector(udt, path + "." + field, rest)
}
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
import scala.reflect.macros.Context
trait SerializeMethodGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with TreeGen[C] with SerializerGen[C] with Loggers[C] =>
import c.universe._
protected def mkSerialize(desc: UDTDescriptor, listImpls: Map[Int, Type]): List[Tree] = {
// val root = mkMethod("serialize", Flag.OVERRIDE | Flag.FINAL, List(("item", desc.tpe), ("record", typeOf[eu.stratosphere.pact.common.`type`.PactRecord])), definitions.UnitTpe, {
val root = mkMethod("serialize", Flag.FINAL, List(("item", desc.tpe), ("record", typeOf[eu.stratosphere.pact.common.`type`.PactRecord])), definitions.UnitTpe, {
val env = GenEnvironment(listImpls, "flat" + desc.id, false, true, true, true)
val stats = genSerialize(desc, Ident("item": TermName), Ident("record": TermName), env)
Block(stats.toList, mkUnit)
})
val aux = desc.getRecursiveRefs map { desc =>
mkMethod("serialize" + desc.id, Flag.PRIVATE | Flag.FINAL, List(("item", desc.tpe), ("record", typeOf[eu.stratosphere.pact.common.`type`.PactRecord])), definitions.UnitTpe, {
val env = GenEnvironment(listImpls, "boxed" + desc.id, true, false, false, true)
val stats = genSerialize(desc, Ident("item": TermName), Ident("record": TermName), env)
Block(stats.toList, mkUnit)
})
}
root +: aux.toList
}
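// The result is one public entry point plus one private helper per recursive
// reference, schematically (the numeric suffix is the referenced descriptor's
// id; names here are illustrative):
//   final def serialize(item: T, record: PactRecord): Unit = { ... }
//   private final def serialize7(item: Inner, record: PactRecord): Unit = { ... }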
private def genSerialize(desc: UDTDescriptor, source: Tree, target: Tree, env: GenEnvironment): Seq[Tree] = desc match {
case PrimitiveDescriptor(id, _, _, _) => {
val chk = env.mkChkIdx(id)
val ser = env.mkSetValue(id, source)
val set = env.mkSetField(id, target)
Seq(mkIf(chk, Block(List(ser), set)))
}
case BoxedPrimitiveDescriptor(id, tpe, _, _, _, unbox) => {
val chk = mkAnd(env.mkChkIdx(id), env.mkChkNotNull(source, tpe))
val ser = env.mkSetValue(id, unbox(source))
val set = env.mkSetField(id, target)
Seq(mkIf(chk, Block(List(ser), set)))
}
case desc @ ListDescriptor(id, tpe, iter, elem) => {
val chk = mkAnd(env.mkChkIdx(id), env.mkChkNotNull(source, tpe))
val upd = desc.getInnermostElem match {
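// Note: "updateBinaryRepresenation" (sic) matches the actual, misspelled
// method name on PactRecord, so it is spelled that way here on purpose.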
case _: RecursiveDescriptor => Some(Apply(Select(target, "updateBinaryRepresenation"), List()))
case _ => None
}
val (init, list) = env.reentrant match {
// This is a bit conservative, but avoids runtime checks
// and/or even more specialized serialize() methods to
// track whether it's safe to reuse the list variable.
case true => {
val listTpe = env.listImpls(id)
val list = mkVal("list" + id, NoFlags, false, listTpe, New(TypeTree(listTpe), List(List())))
(list, Ident("list" + id: TermName))
}
case false => {
val clear = Apply(Select(env.mkSelectWrapper(id), "clear"), List())
(clear, env.mkSelectWrapper(id))
}
}
val body = genSerializeList(elem, iter(source), list, env.copy(chkNull = true))
val set = env.mkSetField(id, target, list)
val stats = (init +: body) :+ set
val updStats = upd ++ stats
Seq(mkIf(chk, Block(updStats.init.toList, updStats.last)))
}
case CaseClassDescriptor(_, tpe, _, _, getters) => {
val chk = env.mkChkNotNull(source, tpe)
val stats = getters filterNot { _.isBaseField } flatMap { case FieldAccessor(sym, _, _, _, desc) => genSerialize(desc, Select(source, sym), target, env.copy(chkNull = true)) }
stats match {
case Nil => Seq()
case _ => Seq(mkIf(chk, mkSingle(stats)))
}
}
case BaseClassDescriptor(id, tpe, Seq(tagField, baseFields @ _*), subTypes) => {
val chk = env.mkChkNotNull(source, tpe)
val fields = baseFields flatMap { (f => genSerialize(f.desc, Select(source, f.getter), target, env.copy(chkNull = true))) }
val cases = subTypes.zipWithIndex.toList map {
case (dSubType, i) => {
val pat = Bind("inst": TermName, Typed(Ident("_"), TypeTree(dSubType.tpe)))
val cast = None
val inst = Ident("inst": TermName)
// val (pat, cast, inst) = {
// val erasedTpe = mkErasedType(env.methodSym, dSubType.tpe)
//
// if (erasedTpe =:= dSubType.tpe) {
//
// val pat = Bind(newTermName("inst"), Typed(Ident("_"), TypeTree(dSubType.tpe)))
// (pat, None, Ident(newTermName("inst")))
//
// } else {
//
// // This avoids type erasure warnings in the generated pattern match
// val pat = Bind(newTermName("erasedInst"), Typed(Ident("_"), TypeTree(erasedTpe)))
// val cast = mkVal("inst", NoFlags, false, dSubType.tpe, mkAsInstanceOf(Ident("erasedInst"))(c.WeakTypeTag(dSubType.tpe)))
// val inst = Ident(cast.symbol)
// (pat, Some(cast), inst)
// }
// }
val tag = genSerialize(tagField.desc, c.literal(i).tree, target, env.copy(chkNull = false))
val code = genSerialize(dSubType, inst, target, env.copy(chkNull = false))
val body = (cast.toSeq ++ tag ++ code) :+ mkUnit
CaseDef(pat, EmptyTree, Block(body.init.toList, body.last))
}
}
Seq(mkIf(chk, Block(fields.toList,Match(source, cases))))
}
case RecursiveDescriptor(id, tpe, refId) => {
// Important: recursive types introduce re-entrant calls to serialize()
val chk = mkAnd(env.mkChkIdx(id), env.mkChkNotNull(source, tpe))
// Persist the outer record prior to recursing, since the call
// is going to reuse all the PactPrimitive wrappers that were
// needed *before* the recursion.
val updTgt = Apply(Select(target, "updateBinaryRepresenation"), List())
val rec = mkVal("record" + id, NoFlags, false, typeOf[eu.stratosphere.pact.common.`type`.PactRecord], New(TypeTree(typeOf[eu.stratosphere.pact.common.`type`.PactRecord]), List(List())))
val ser = env.mkCallSerialize(refId, source, Ident("record" + id: TermName))
// Persist the new inner record after recursing, since the
// current call is going to reuse all the PactPrimitive
// wrappers that are needed *after* the recursion.
val updRec = Apply(Select(Ident("record" + id: TermName), "updateBinaryRepresenation"), List())
val set = env.mkSetField(id, target, Ident("record" + id: TermName))
Seq(mkIf(chk, Block(List(updTgt, rec, ser, updRec), set)))
}
}
private def genSerializeList(elem: UDTDescriptor, iter: Tree, target: Tree, env: GenEnvironment): Seq[Tree] = {
val it = mkVal("it", NoFlags, false, mkIteratorOf(elem.tpe), iter)
val loop = mkWhile(Select(Ident("it": TermName), "hasNext")) {
val item = mkVal("item", NoFlags, false, elem.tpe, Select(Ident("it": TermName), "next"))
val (stats, value) = elem match {
case PrimitiveDescriptor(_, _, _, wrapper) => (Seq(), New(TypeTree(wrapper), List(List(Ident("item": TermName)))))
case BoxedPrimitiveDescriptor(_, _, _, wrapper, _, unbox) => (Seq(), New(TypeTree(wrapper), List(List(unbox(Ident("item": TermName))))))
case ListDescriptor(id, _, iter, innerElem) => {
val listTpe = env.listImpls(id)
val list = mkVal("list" + id, NoFlags, false, listTpe, New(TypeTree(listTpe), List(List())))
val body = genSerializeList(innerElem, iter(Ident("item": TermName)), Ident("list" + id: TermName), env)
(list +: body, Ident("list" + id: TermName))
}
case RecursiveDescriptor(id, tpe, refId) => {
val rec = mkVal("record" + id, NoFlags, false, typeOf[eu.stratosphere.pact.common.`type`.PactRecord], New(TypeTree(typeOf[eu.stratosphere.pact.common.`type`.PactRecord]), List(List())))
val ser = env.mkCallSerialize(refId, Ident("item": TermName), Ident("record" + id: TermName))
val updRec = Apply(Select(Ident("record" + id: TermName), "updateBinaryRepresenation"), List())
(Seq(rec, ser, updRec), Ident("record" + id: TermName))
}
case _ => {
val rec = mkVal("record", NoFlags, false, typeOf[eu.stratosphere.pact.common.`type`.PactRecord], New(TypeTree(typeOf[eu.stratosphere.pact.common.`type`.PactRecord]), List(List())))
val ser = genSerialize(elem, Ident("item": TermName), Ident("record": TermName), env.copy(idxPrefix = "boxed" + elem.id, chkIndex = false, chkNull = false))
val upd = Apply(Select(Ident("record": TermName), "updateBinaryRepresenation"), List())
((rec +: ser) :+ upd, Ident("record": TermName))
}
}
val chk = env.mkChkNotNull(Ident("item": TermName), elem.tpe)
val add = Apply(Select(target, "add"), List(value))
val addNull = Apply(Select(target, "add"), List(mkNull))
Block(List(item), mkIf(chk, mkSingle(stats :+ add), addNull))
}
Seq(it, loop)
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
import scala.reflect.macros.Context
import eu.stratosphere.pact.common.`type`.PactRecord
import eu.stratosphere.scala.analysis.UDTSerializer
trait SerializerGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with UDTAnalyzer[C] with TreeGen[C] with SerializeMethodGen[C] with DeserializeMethodGen[C] with Loggers[C] =>
import c.universe._
def mkUdtSerializerClass[T: c.WeakTypeTag](name: String = "", creatorName: String = "createSerializer"): (ClassDef, Tree) = {
val desc = getUDTDescriptor(weakTypeOf[T])
desc match {
case UnsupportedDescriptor(_, _, errs) => {
val errorString = errs.mkString("\n")
c.abort(c.enclosingPosition, s"Error analyzing UDT ${weakTypeOf[T]}: $errorString")
}
case _ =>
}
val serName = newTypeName("UDTSerializerImpl" + name)
val ser = mkClass(serName, Flag.FINAL, List(weakTypeOf[UDTSerializer[T]]), {
val (listImpls, listImplTypes) = mkListImplClasses(desc)
val indexMapIter = Select(Ident("indexMap": TermName), "iterator": TermName)
val (fields1, inits1) = mkIndexes(desc.id, getIndexFields(desc).toList, false, indexMapIter)
val (fields2, inits2) = mkBoxedIndexes(desc)
val fields = fields1 ++ fields2
val init = inits1 ++ inits2 match {
case Nil => Nil
case inits => List(mkMethod("init", Flag.OVERRIDE | Flag.FINAL, List(), definitions.UnitTpe, Block(inits, mkUnit)))
}
val (wrapperFields, wrappers) = mkPactWrappers(desc, listImplTypes)
val mutableUdts = desc.flatten.toList flatMap {
case cc @ CaseClassDescriptor(_, _, true, _, _) => Some(cc)
case _ => None
} distinct
val mutableUdtInsts = mutableUdts map { u => mkMutableUdtInst(u) }
val helpers = listImpls ++ fields ++ wrapperFields ++ mutableUdtInsts ++ init
val ctor = mkMethod(nme.CONSTRUCTOR.toString(), NoFlags, List(("indexMap", typeOf[Array[Int]])), NoType, {
Block(List(mkSuperCall(List(Ident(newTermName("indexMap"))))), mkUnit)
})
// val methods = List(ctor)// ++ List(mkGetFieldIndex(desc)) //++ mkSerialize(desc, listImplTypes) ++ mkDeserialize(desc, listImplTypes)
val methods = List(ctor) ++ mkSerialize(desc, listImplTypes) ++ mkDeserialize(desc, listImplTypes)
helpers ++ methods
})
val (_, serTpe) = typeCheck(ser)
val createSerializer = mkMethod(creatorName, Flag.OVERRIDE, List(("indexMap", typeOf[Array[Int]])), NoType, {
Block(List(), mkCtorCall(serTpe, List(Ident(newTermName("indexMap")))))
})
(ser, createSerializer)
}
private def mkListImplClass[T <: eu.stratosphere.pact.common.`type`.Value: c.WeakTypeTag]: (Tree, Type) = {
val listImplName = c.fresh[TypeName]("PactListImpl")
val tpe = weakTypeOf[eu.stratosphere.pact.common.`type`.base.PactList[T]]
val listDef = mkClass(listImplName, Flag.FINAL, List(tpe), {
List(mkMethod(nme.CONSTRUCTOR.toString(), NoFlags, List(), NoType, Block(List(mkSuperCall()), mkUnit)))
})
typeCheck(listDef)
}
def mkListImplClasses(desc: UDTDescriptor): (List[Tree], Map[Int, Type]) = {
desc match {
case ListDescriptor(id, _, _, elem: ListDescriptor) => {
val (defs, tpes) = mkListImplClasses(elem)
val (listDef, listTpe) = mkListImplClass(c.WeakTypeTag(tpes(elem.id)))
(defs :+ listDef, tpes + (id -> listTpe))
}
case ListDescriptor(id, _, _, elem: PrimitiveDescriptor) => {
val (classDef, tpe) = mkListImplClass(c.WeakTypeTag(elem.wrapper))
(List(classDef), Map(id -> tpe))
}
case ListDescriptor(id, _, _, elem: BoxedPrimitiveDescriptor) => {
val (classDef, tpe) = mkListImplClass(c.WeakTypeTag(elem.wrapper))
(List(classDef), Map(id -> tpe))
}
case ListDescriptor(id, _, _, elem) => {
val (classDefs, tpes) = mkListImplClasses(elem)
val (classDef, tpe) = mkListImplClass(c.WeakTypeTag(typeOf[eu.stratosphere.pact.common.`type`.PactRecord]))
(classDefs :+ classDef, tpes + (id -> tpe))
}
case BaseClassDescriptor(_, _, getters, subTypes) => {
val (defs, tpes) = getters.foldLeft((List[Tree](), Map[Int, Type]())) { (result, f) =>
val (defs, tpes) = result
val (newDefs, newTpes) = mkListImplClasses(f.desc)
(defs ++ newDefs, tpes ++ newTpes)
}
val (subDefs, subTpes) = subTypes.foldLeft((List[Tree](), Map[Int, Type]())) { (result, s) =>
val (defs, tpes) = result
val (innerDefs, innerTpes) = mkListImplClasses(s)
(defs ++ innerDefs, tpes ++ innerTpes)
}
(defs ++ subDefs, tpes ++ subTpes)
}
case CaseClassDescriptor(_, _, _, _, getters) => {
getters.foldLeft((List[Tree](), Map[Int, Type]())) { (result, f) =>
val (defs, tpes) = result
val (newDefs, newTpes) = mkListImplClasses(f.desc)
(defs ++ newDefs, tpes ++ newTpes)
}
}
case _ => {
(List[Tree](), Map[Int, Type]())
}
}
}
private def mkIndexes(descId: Int, descFields: List[UDTDescriptor], boxed: Boolean, indexMapIter: Tree): (List[Tree], List[Tree]) = {
val prefix = (if (boxed) "boxed" else "flat") + descId
val iterName = prefix + "Iter"
val iter = mkVal(iterName, Flag.PRIVATE, true, mkIteratorOf(definitions.IntTpe), indexMapIter)
val fieldsAndInits = descFields map {
case d => {
val next = Apply(Select(Ident(iterName: TermName), "next": TermName), Nil)
val idxField = mkVal(prefix + "Idx" + d.id, Flag.PRIVATE, false, definitions.IntTpe, next)
(List(idxField), Nil)
}
}
val (fields, inits) = fieldsAndInits.unzip
(iter +: fields.flatten, inits.flatten)
}
protected def getIndexFields(desc: UDTDescriptor): Seq[UDTDescriptor] = desc match {
// Flatten product types
case CaseClassDescriptor(_, _, _, _, getters) => getters filterNot { _.isBaseField } flatMap { f => getIndexFields(f.desc) }
// TODO: Rather than laying out subclass fields sequentially, just reserve enough fields for the largest subclass.
// This is tricky because subclasses can contain opaque descriptors, so we don't know how many fields we need until runtime.
case BaseClassDescriptor(id, _, getters, subTypes) => (getters flatMap { f => getIndexFields(f.desc) }) ++ (subTypes flatMap getIndexFields)
case _ => Seq(desc)
}
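// Flattening example (hypothetical types): for
//   case class Inner(a: Int, b: String)
//   case class Outer(x: Int, inner: Inner)
// getIndexFields yields the leaves in declaration order (x, inner.a,
// inner.b), one record field index per leaf.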
private def mkBoxedIndexes(desc: UDTDescriptor): (List[Tree], List[Tree]) = {
def getBoxedDescriptors(d: UDTDescriptor): Seq[UDTDescriptor] = d match {
case ListDescriptor(_, _, _, elem: BaseClassDescriptor) => elem +: getBoxedDescriptors(elem)
case ListDescriptor(_, _, _, elem: CaseClassDescriptor) => elem +: getBoxedDescriptors(elem)
case ListDescriptor(_, _, _, elem) => getBoxedDescriptors(elem)
case CaseClassDescriptor(_, _, _, _, getters) => getters filterNot { _.isBaseField } flatMap { f => getBoxedDescriptors(f.desc) }
case BaseClassDescriptor(id, _, getters, subTypes) => (getters flatMap { f => getBoxedDescriptors(f.desc) }) ++ (subTypes flatMap getBoxedDescriptors)
case RecursiveDescriptor(_, _, refId) => desc.findById(refId).map(_.mkRoot).toSeq
case _ => Seq()
}
val fieldsAndInits = getBoxedDescriptors(desc).distinct.toList flatMap { d =>
// The way this is done here is a relic of the former OpaqueDescriptor
// support: there, the widths were not simply mkOne but varying numbers of
// fields retrieved from the opaque UDT descriptors.
getIndexFields(d).toList match {
case Nil => None
case fields => {
val widths = fields map {
case _ => mkOne
}
val sum = widths.reduce { (s, i) => Apply(Select(s, "$plus": TermName), List(i)) }
val range = Apply(Select(Ident("scala": TermName), "Range": TermName), List(mkZero, sum))
Some(mkIndexes(d.id, fields, true, Select(range, "iterator": TermName)))
}
}
}
val (fields, inits) = fieldsAndInits.unzip
(fields.flatten, inits.flatten)
}
private def mkPactWrappers(desc: UDTDescriptor, listImpls: Map[Int, Type]): (List[Tree], List[(Int, Type)]) = {
def getFieldTypes(desc: UDTDescriptor): Seq[(Int, Type)] = desc match {
case PrimitiveDescriptor(id, _, _, wrapper) => Seq((id, wrapper))
case BoxedPrimitiveDescriptor(id, _, _, wrapper, _, _) => Seq((id, wrapper))
case d @ ListDescriptor(id, _, _, elem) => {
val listField = (id, listImpls(id))
val elemFields = d.getInnermostElem match {
case elem: CaseClassDescriptor => getFieldTypes(elem)
case elem: BaseClassDescriptor => getFieldTypes(elem)
case _ => Seq()
}
listField +: elemFields
}
case CaseClassDescriptor(_, _, _, _, getters) => getters filterNot { _.isBaseField } flatMap { f => getFieldTypes(f.desc) }
case BaseClassDescriptor(_, _, getters, subTypes) => (getters flatMap { f => getFieldTypes(f.desc) }) ++ (subTypes flatMap getFieldTypes)
case _ => Seq()
}
getFieldTypes(desc) toList match {
case Nil => (Nil, Nil)
case types =>
val fields = types map { case (id, tpe) => mkVar("w" + id, Flag.PRIVATE, true, tpe, mkCtorCall(tpe, List())) }
(fields, types)
}
}
private def mkMutableUdtInst(desc: CaseClassDescriptor): Tree = {
val args = desc.getters map {
case FieldAccessor(_, _, fTpe, _, _) => {
mkDefault(fTpe)
}
}
val ctor = mkCtorCall(desc.tpe, args.toList)
mkVar("mutableUdtInst" + desc.id, Flag.PRIVATE, true, desc.tpe, ctor)
}
private def mkGetFieldIndex(desc: UDTDescriptor): Tree = {
val env = GenEnvironment(Map(), "flat" + desc.id, false, true, true, true)
def mkCases(desc: UDTDescriptor, path: Seq[String]): Seq[(Seq[String], Tree)] = desc match {
case PrimitiveDescriptor(id, _, _, _) => Seq((path, mkList(List(env.mkSelectIdx(id)))))
case BoxedPrimitiveDescriptor(id, _, _, _, _, _) => Seq((path, mkList(List(env.mkSelectIdx(id)))))
case BaseClassDescriptor(_, _, Seq(tag, getters @ _*), _) => {
val tagCase = Seq((path :+ "getClass", mkList(List(env.mkSelectIdx(tag.desc.id)))))
val fieldCases = getters flatMap { f => mkCases(f.desc, path :+ f.getter.name.toString) }
tagCase ++ fieldCases
}
case CaseClassDescriptor(_, _, _, _, getters) => {
def fieldCases = getters flatMap { f => mkCases(f.desc, path :+ f.getter.name.toString) }
val allFieldsCase = desc match {
case _ if desc.isPrimitiveProduct => {
val nonRest = fieldCases filter { case (p, _) => p.size == path.size + 1 } map { _._2 }
Seq((path, nonRest.reduceLeft((z, f) => Apply(Select(z, "$plus$plus"), List(f)))))
}
case _ => Seq()
}
allFieldsCase ++ fieldCases
}
case _ => Seq()
}
def mkPat(path: Seq[String]): Tree = {
val seqUnapply = TypeApply(mkSelect("scala", "collection", "Seq", "unapplySeq"), List(TypeTree(typeOf[String])))
val fun = Apply(seqUnapply, List(Ident(nme.WILDCARD)))
val args = path map {
case null => Bind(newTermName("rest"), Star(Ident(newTermName("_"))))
case s => Literal(Constant(s))
}
UnApply(fun, args.toList)
}
mkMethod("getFieldIndex", Flag.FINAL, List(("selection", mkSeqOf(typeOf[String]))), mkListOf(typeOf[Int]), {
// mkMethod("getFieldIndex", Flag.OVERRIDE | Flag.FINAL, List(("selection", mkSeqOf(typeOf[String]))), mkListOf(typeOf[Int]), {
val cases = mkCases(desc, Seq()) map { case (path, idxs) => CaseDef(mkPat(path), EmptyTree, idxs) }
// val errCase = CaseDef(Ident("_"), EmptyTree, Apply(Ident(newTermName("println")), List(Ident("selection"))))
val errCase = CaseDef(Ident("_"), EmptyTree, (reify {throw new RuntimeException("Invalid selection")}).tree )
// Match(Ident("selection"), (cases :+ errCase).toList)
Match(Ident("selection"), List(errCase))
})
}
protected case class GenEnvironment(listImpls: Map[Int, Type], idxPrefix: String, reentrant: Boolean, allowRecycling: Boolean, chkIndex: Boolean, chkNull: Boolean) {
private def isNullable(tpe: Type) = typeOf[Null] <:< tpe && tpe <:< typeOf[AnyRef]
def mkChkNotNull(source: Tree, tpe: Type): Tree = if (isNullable(tpe) && chkNull) Apply(Select(source, "$bang$eq": TermName), List(mkNull)) else EmptyTree
def mkChkIdx(fieldId: Int): Tree = if (chkIndex) Apply(Select(mkSelectIdx(fieldId), "$greater$eq": TermName), List(mkZero)) else EmptyTree
def mkSelectIdx(fieldId: Int): Tree = Ident(newTermName(idxPrefix + "Idx" + fieldId))
def mkSelectSerializer(fieldId: Int): Tree = Ident(newTermName(idxPrefix + "Ser" + fieldId))
def mkSelectWrapper(fieldId: Int): Tree = Ident(newTermName("w" + fieldId))
def mkSelectMutableUdtInst(udtId: Int): Tree = Ident(newTermName("mutableUdtInst" + udtId))
def mkCallSetMutableField(udtId: Int, setter: Symbol, source: Tree): Tree = Apply(Select(mkSelectMutableUdtInst(udtId), setter), List(source))
def mkCallSerialize(refId: Int, source: Tree, target: Tree): Tree = Apply(Ident(newTermName("serialize" + refId)), List(source, target))
def mkCallDeserialize(refId: Int, source: Tree): Tree = Apply(Ident(newTermName("deserialize" + refId)), List(source))
def mkSetField(fieldId: Int, record: Tree): Tree = mkSetField(fieldId, record, mkSelectWrapper(fieldId))
def mkSetField(fieldId: Int, record: Tree, wrapper: Tree): Tree = Apply(Select(record, "setField": TermName), List(mkSelectIdx(fieldId), wrapper))
def mkGetFieldInto(fieldId: Int, record: Tree): Tree = mkGetFieldInto(fieldId, record, mkSelectWrapper(fieldId))
def mkGetFieldInto(fieldId: Int, record: Tree, wrapper: Tree): Tree = Apply(Select(record, "getFieldInto": TermName), List(mkSelectIdx(fieldId), wrapper))
def mkSetValue(fieldId: Int, value: Tree): Tree = mkSetValue(mkSelectWrapper(fieldId), value)
def mkSetValue(wrapper: Tree, value: Tree): Tree = Apply(Select(wrapper, "setValue": TermName), List(value))
def mkGetValue(fieldId: Int): Tree = mkGetValue(mkSelectWrapper(fieldId))
def mkGetValue(wrapper: Tree): Tree = Apply(Select(wrapper, "getValue": TermName), List())
def mkNotIsNull(fieldId: Int, record: Tree): Tree = Select(Apply(Select(record, "isNull": TermName), List(mkSelectIdx(fieldId))), "unary_$bang": TermName)
}
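// A worked example of the helpers above, for field id 3 and idxPrefix
// "flat0" (names as produced by mkSelectIdx/mkSelectWrapper):
//   mkChkIdx(3)           ~>  flat0Idx3 >= 0
//   mkSetField(3, rec)    ~>  rec.setField(flat0Idx3, w3)
//   mkGetFieldInto(3, r)  ~>  r.getFieldInto(flat0Idx3, w3)
//   mkNotIsNull(3, r)     ~>  !r.isNull(flat0Idx3)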
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
import scala.reflect.macros.Context
trait TreeGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with Loggers[C] =>
import c.universe._
def mkDefault(tpe: Type): Tree = {
import definitions._
tpe match {
case definitions.BooleanTpe => Literal(Constant(false))
case definitions.ByteTpe => Literal(Constant(0: Byte))
case definitions.CharTpe => Literal(Constant(0: Char))
case definitions.DoubleTpe => Literal(Constant(0: Double))
case definitions.FloatTpe => Literal(Constant(0: Float))
case definitions.IntTpe => Literal(Constant(0: Int))
case definitions.LongTpe => Literal(Constant(0: Long))
case definitions.ShortTpe => Literal(Constant(0: Short))
case definitions.UnitTpe => Literal(Constant(()))
case _ => Literal(Constant(null: Any))
}
}
def mkUnit = reify( () ).tree
def mkNull = reify( null ).tree
def mkZero = reify( 0 ).tree
def mkOne = reify( 1 ).tree
def mkAsInstanceOf[T: c.WeakTypeTag](source: Tree): Tree = reify(c.Expr(source).splice.asInstanceOf[T]).tree
def maybeMkAsInstanceOf[S: c.WeakTypeTag, T: c.WeakTypeTag](source: Tree): Tree = {
if (weakTypeOf[S] <:< weakTypeOf[T])
source
else
mkAsInstanceOf[T](source)
}
// def mkIdent(target: Symbol): Tree = Ident(target) setType target.tpe
def mkSelect(rootModule: String, path: String*): Tree = mkSelect(Ident(newTermName(rootModule)), path: _*)
def mkSelect(source: Tree, path: String*): Tree = path.foldLeft(source) { (ret, item) => Select(ret, newTermName(item)) }
def mkSelectSyms(source: Tree, path: Symbol*): Tree = path.foldLeft(source) { (ret, item) => Select(ret, item) }
def mkCall(root: Tree, path: String*)(args: List[Tree]) = Apply(mkSelect(root, path: _*), args)
def mkSeq(items: List[Tree]): Tree = Apply(mkSelect("scala", "collection", "Seq", "apply"), items)
def mkList(items: List[Tree]): Tree = Apply(mkSelect("scala", "collection", "immutable", "List", "apply"), items)
def mkVal(name: String, flags: FlagSet, transient: Boolean, valTpe: Type, value: Tree): Tree = {
ValDef(Modifiers(flags), newTermName(name), TypeTree(valTpe), value)
}
def mkVar(name: String, flags: FlagSet, transient: Boolean, valTpe: Type, value: Tree): Tree = {
mkVal(name, flags | Flag.MUTABLE, transient, valTpe, value)
}
def mkValAndGetter(name: String, flags: FlagSet, valTpe: Type, value: Tree): List[Tree] = {
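// The trailing space mimics the compiler's internal "local" name convention
// for the private field backing a getter, keeping the two from colliding.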
val fieldName = name + " "
val valDef = mkVal(fieldName, Flag.PRIVATE, false, valTpe, value)
val defDef = mkMethod(name, flags, Nil, valTpe, Ident(newTermName(fieldName)))
List(valDef, defDef)
}
def mkVarAndLazyGetter(name: String, flags: FlagSet, valTpe: Type, value: Tree): (Tree, Tree) = {
val fieldName = name + " "
val field = mkVar(fieldName, NoFlags, false, valTpe, mkNull)
val fieldSel = Ident(newTermName(fieldName))
val getter = mkMethod(name, flags, Nil, valTpe, {
val eqeq = Select(fieldSel, newTermName("$eq$eq"))
val chk = Apply(eqeq, List(mkNull))
val init = Assign(fieldSel, value)
Block(List(If(chk, init, EmptyTree)), fieldSel)
})
(field, getter)
}
def mkIf(cond: Tree, bodyT: Tree): Tree = mkIf(cond, bodyT, EmptyTree)
def mkIf(cond: Tree, bodyT: Tree, bodyF: Tree): Tree = cond match {
case EmptyTree => bodyT
case _ => If(cond, bodyT, bodyF)
}
def mkSingle(stats: Seq[Tree]): Tree = stats match {
case Seq() => EmptyTree
case Seq(stat) => stat
case _ => Block(stats.init.toList, stats.last)
}
def mkAnd(cond1: Tree, cond2: Tree): Tree = cond1 match {
case EmptyTree => cond2
case _ => cond2 match {
case EmptyTree => cond1
case _ => reify(c.Expr[Boolean](cond1).splice && c.Expr[Boolean](cond2).splice).tree
}
}
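// mkIf and mkAnd treat EmptyTree as "no condition", so optional checks melt
// away instead of producing dead code:
//   mkAnd(EmptyTree, c)  ~>  c
//   mkIf(EmptyTree, t)   ~>  t
//   mkAnd(c1, c2)        ~>  c1 && c2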
def mkMethod(name: String, flags: FlagSet, args: List[(String, Type)], ret: Type, impl: Tree): Tree = {
val valParams = args map { case (name, tpe) => ValDef(Modifiers(Flag.PARAM), newTermName(name), TypeTree(tpe), EmptyTree) }
DefDef(Modifiers(flags), newTermName(name), Nil, List(valParams), TypeTree(ret), impl)
}
def mkClass(name: TypeName, flags: FlagSet, parents: List[Type], members: List[Tree]): ClassDef = {
val parentTypeTrees = parents map { TypeTree(_) }
val selfType = ValDef(Modifiers(), nme.WILDCARD, TypeTree(NoType), EmptyTree)
ClassDef(Modifiers(flags), name, Nil, Template(parentTypeTrees, selfType, members))
}
def mkThrow(tpe: Type, msg: Tree): Tree = Throw(Apply(Select(New(TypeTree(tpe)), nme.CONSTRUCTOR), List(msg)))
// def mkThrow(tpe: Type, msg: Tree): Tree = Throw(New(TypeTree(tpe)), List(List(msg))))
def mkThrow(tpe: Type, msg: String): Tree = mkThrow(tpe, c.literal(msg).tree)
def mkThrow(msg: String): Tree = mkThrow(typeOf[java.lang.RuntimeException], msg)
implicit def tree2Ops[T <: Tree](tree: T) = new {
// copy of Tree.copyAttrs, since that method is private
// def copyAttrs(from: Tree): T = {
// tree.pos = from.pos
// tree.tpe = from.tpe
// if (tree.hasSymbol) tree.symbol = from.symbol
// tree
// }
def getSimpleClassName: String = {
val name = tree.getClass.getName
val idx = math.max(name.lastIndexOf('$'), name.lastIndexOf('.')) + 1
name.substring(idx)
}
}
def mkIteratorOf(tpe: Type) = {
def makeIt[T: c.WeakTypeTag] = weakTypeOf[Iterator[T]]
makeIt(c.WeakTypeTag(tpe))
}
def mkSeqOf(tpe: Type) = {
def makeIt[T: c.WeakTypeTag] = weakTypeOf[Seq[T]]
makeIt(c.WeakTypeTag(tpe))
}
def mkListOf(tpe: Type) = {
def makeIt[T: c.WeakTypeTag] = weakTypeOf[List[T]]
makeIt(c.WeakTypeTag(tpe))
}
def mkBuilderOf(elemTpe: Type, listTpe: Type) = {
def makeIt[ElemTpe: c.WeakTypeTag, ListTpe: c.WeakTypeTag] = weakTypeOf[scala.collection.mutable.Builder[ElemTpe, ListTpe]]
makeIt(c.WeakTypeTag(elemTpe), c.WeakTypeTag(listTpe))
}
def mkCanBuildFromOf(fromTpe: Type, elemTpe: Type, toTpe: Type) = {
def makeIt[From: c.WeakTypeTag, Elem: c.WeakTypeTag, To: c.WeakTypeTag] = weakTypeOf[scala.collection.generic.CanBuildFrom[From, Elem, To]]
makeIt(c.WeakTypeTag(fromTpe), c.WeakTypeTag(elemTpe), c.WeakTypeTag(toTpe))
}
def mkCtorCall(tpe: Type, args: List[Tree]) = Apply(Select(New(TypeTree(tpe)), nme.CONSTRUCTOR), args)
def mkSuperCall(args: List[Tree] = List()) = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), args)
def mkWhile(cond: Tree)(body: Tree): Tree = {
val lblName = c.fresh[TermName]("while")
val jump = Apply(Ident(lblName), Nil)
val block = body match {
case Block(stats, expr) => Block(stats :+ expr, jump)
case _ => Block(List(body), jump)
}
LabelDef(lblName, Nil, If(cond, block, EmptyTree))
}
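// This is the same encoding the compiler uses to lower while loops: a
// LabelDef that tail-calls itself under a fresh name, schematically
//   whileN(): if (cond) { body; whileN() }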
def typeCheck(classDef: ClassDef): (ClassDef, Type) = {
val block = Block(List(classDef), EmptyTree)
val checkedBlock = c.typeCheck(block)
// extract the class def from the block again
val checkedDef = checkedBlock match {
case Block((cls: ClassDef) :: Nil, _) => cls
}
val tpe = checkedDef.symbol.asClass.toType
(checkedDef, tpe)
}
def extractOneInputUdf(fun: Tree) = {
val (paramName, udfBody) = fun match {
case Function(List(param), body) => (param.name.toString, body)
case _ => c.abort(c.enclosingPosition, "Could not extract user defined function, got: " + show(fun))
}
val uncheckedUdfBody = c.resetAllAttrs(udfBody)
(paramName, uncheckedUdfBody)
}
def extractTwoInputUdf(fun: Tree) = {
val (param1Name, param2Name, udfBody) = fun match {
case Function(List(param1, param2), body) => (param1.name.toString, param2.name.toString, body)
case _ => c.abort(c.enclosingPosition, "Could not extract user defined function, got: " + show(fun))
}
val uncheckedUdfBody = c.resetAllAttrs(udfBody)
(param1Name, param2Name, uncheckedUdfBody)
}
def extractClass(block: Tree) = block match {
case Block((cls: ClassDef) :: Nil, _) => cls
case e => throw new RuntimeException("No class def at first position in block.")
}
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
import scala.collection.GenTraversableOnce
import scala.collection.mutable
import scala.reflect.macros.Context
import scala.util.DynamicVariable
import eu.stratosphere.pact.common.`type`.base.PactBoolean
import eu.stratosphere.pact.common.`type`.base.PactByte
import eu.stratosphere.pact.common.`type`.base.PactCharacter
import eu.stratosphere.pact.common.`type`.base.PactDouble
import eu.stratosphere.pact.common.`type`.base.PactFloat
import eu.stratosphere.pact.common.`type`.base.PactInteger
import eu.stratosphere.pact.common.`type`.base.PactString
import eu.stratosphere.pact.common.`type`.base.PactLong
import eu.stratosphere.pact.common.`type`.base.PactShort
import scala.Option.option2Iterable
trait UDTAnalyzer[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with Loggers[C] =>
import c.universe._
// This value is controlled by the udtRecycling compiler option
var enableMutableUDTs = false
private val mutableTypes = mutable.Set[Type]()
def getUDTDescriptor(tpe: Type): UDTDescriptor = (new UDTAnalyzerInstance with Logger).analyze(tpe)
private def normTpe(tpe: Type): Type = {
// TODO Figure out what the heck this does
// val currentThis = ThisType(localTyper.context.enclClass.owner)
// currentThis.baseClasses.foldLeft(tpe map { _.dealias }) { (tpe, base) => tpe.substThis(base, currentThis) }
tpe
}
private def typeArgs(tpe: Type) = tpe match { case TypeRef(_, _, args) => args }
private class UDTAnalyzerInstance { this: Logger =>
private val cache = new UDTAnalyzerCache()
def analyze(tpe: Type): UDTDescriptor = {
val normed = normTpe(tpe)
cache.getOrElseUpdate(normed) { id =>
normed match {
case PrimitiveType(default, wrapper) => PrimitiveDescriptor(id, normed, default, wrapper)
case BoxedPrimitiveType(default, wrapper, box, unbox) => BoxedPrimitiveDescriptor(id, normed, default, wrapper, box, unbox)
case ListType(elemTpe, iter) => analyzeList(id, normed, elemTpe, iter)
case CaseClassType() => analyzeCaseClass(id, normed)
case BaseClassType() => analyzeClassHierarchy(id, normed)
case _ => UnsupportedDescriptor(id, normed, Seq("Unsupported type " + normed))
}
}
}
private def analyzeList(id: Int, tpe: Type, elemTpe: Type, iter: Tree => Tree): UDTDescriptor = analyze(elemTpe) match {
case UnsupportedDescriptor(_, _, errs) => UnsupportedDescriptor(id, tpe, errs)
case desc => ListDescriptor(id, tpe, iter, desc)
}
private def analyzeClassHierarchy(id: Int, tpe: Type): UDTDescriptor = {
val tagField = {
val (intTpe, intDefault, intWrapper) = PrimitiveType.intPrimitive
FieldAccessor(NoSymbol, NoSymbol, NullaryMethodType(intTpe), true, PrimitiveDescriptor(cache.newId, intTpe, intDefault, intWrapper))
}
// c.info(c.enclosingPosition, "KNOWN SUBCLASSES: " + tpe.typeSymbol.asClass.knownDirectSubclasses.toList, true)
val subTypes = tpe.typeSymbol.asClass.knownDirectSubclasses.toList flatMap { d =>
val dTpe = // verbosely[Type] { dTpe => d.tpe + " <: " + tpe + " instantiated as " + dTpe + " (" + (if (dTpe <:< tpe) "Valid" else "Invalid") + " subtype)" }
{
val tArgs = (tpe.typeSymbol.asClass.typeParams, typeArgs(tpe)).zipped.toMap
val dArgs = d.asClass.typeParams map { dp =>
val tArg = tArgs.keySet.find { tp => dp == tp.typeSignature.asSeenFrom(d.typeSignature, tpe.typeSymbol).typeSymbol }
tArg map { tArgs(_) } getOrElse dp.typeSignature
}
normTpe(appliedType(d.asType.toType, dArgs))
}
// c.info(c.enclosingPosition, "dTpe: " + dTpe, true)
if (dTpe <:< tpe)
Some(analyze(dTpe))
else
None
}
// c.info(c.enclosingPosition, c.enclosingRun.units.size + " SUBTYPES: " + subTypes, true)
val errors = subTypes flatMap { _.findByType[UnsupportedDescriptor] }
// c.info(c.enclosingPosition, "ERROS: " + errors, true)
errors match {
case _ :: _ => UnsupportedDescriptor(id, tpe, errors flatMap { case UnsupportedDescriptor(_, subType, errs) => errs map { err => "Subtype " + subType + " - " + err } })
case Nil if subTypes.isEmpty => UnsupportedDescriptor(id, tpe, Seq("No instantiable subtypes found for base class"))
case Nil => {
val (tParams, tArgs) = tpe.typeSymbol.asClass.typeParams.zip(typeArgs(tpe)).unzip
val baseMembers = tpe.members filter { f => f.isMethod } filter { f => f.asMethod.isSetter } map {
f => (f, f.asMethod.setter, normTpe(f.asMethod.returnType))
}
val subMembers = subTypes map {
case BaseClassDescriptor(_, _, getters, _) => getters
case CaseClassDescriptor(_, _, _, _, getters) => getters
case _ => Seq()
}
val baseFields = baseMembers flatMap {
case (bGetter, bSetter, bTpe) => {
val accessors = subMembers map {
_ find { sf =>
sf.getter.name == bGetter.name && sf.tpe.termSymbol.asMethod.returnType <:< bTpe.termSymbol.asMethod.returnType
}
}
accessors.forall { _.isDefined } match {
case true => Some(FieldAccessor(bGetter, bSetter, bTpe, true, analyze(bTpe.termSymbol.asMethod.returnType)))
case false => None
}
}
}
def wireBaseFields(desc: UDTDescriptor): UDTDescriptor = {
def updateField(field: FieldAccessor) = {
baseFields find { bf => bf.getter.name == field.getter.name } match {
case Some(FieldAccessor(_, _, _, _, desc)) => field.copy(isBaseField = true, desc = desc)
case None => field
}
}
desc match {
case desc @ BaseClassDescriptor(_, _, getters, subTypes) => desc.copy(getters = getters map updateField, subTypes = subTypes map wireBaseFields)
case desc @ CaseClassDescriptor(_, _, _, _, getters) => desc.copy(getters = getters map updateField)
case _ => desc
}
}
//Debug.report("BaseClass " + tpe + " has shared fields: " + (baseFields.map { m => m.sym.name + ": " + m.tpe }))
BaseClassDescriptor(id, tpe, tagField +: (baseFields.toSeq), subTypes map wireBaseFields)
}
}
}
private def analyzeCaseClass(id: Int, tpe: Type): UDTDescriptor = {
tpe.baseClasses exists { bc => !(bc == tpe.typeSymbol) && bc.asClass.isCaseClass } match {
case true => UnsupportedDescriptor(id, tpe, Seq("Case-to-case inheritance is not supported."))
case false => {
val ctors = tpe.declarations collect {
case m: MethodSymbol if m.isPrimaryConstructor => m
}
ctors match {
case c1 :: c2 :: _ => UnsupportedDescriptor(id, tpe, Seq("Multiple constructors found, this is not supported."))
case c :: Nil => {
val caseFields = c.paramss.flatten.map {
sym =>
{
val methodSym = tpe.member(sym.name).asMethod
(methodSym.getter, methodSym.setter, methodSym.returnType.asSeenFrom(tpe, tpe.typeSymbol))
}
}
val fields = caseFields map {
case (fgetter, fsetter, fTpe) => FieldAccessor(fgetter, fsetter, fTpe, false, analyze(fTpe))
}
val mutable = maybeVerbosely[Boolean](m => m && mutableTypes.add(tpe))(_ => "Detected recyclable type: " + tpe) {
enableMutableUDTs && (fields forall { f => f.setter != NoSymbol })
}
fields filter { _.desc.isInstanceOf[UnsupportedDescriptor] } match {
case errs @ _ :: _ => {
val msgs = errs flatMap { f =>
(f: @unchecked) match {
case FieldAccessor(fgetter, _, _, _, UnsupportedDescriptor(_, fTpe, errors)) => errors map { err => "Field " + fgetter.name + ": " + fTpe + " - " + err }
}
}
UnsupportedDescriptor(id, tpe, msgs)
}
case Nil => CaseClassDescriptor(id, tpe, mutable, c, fields.toSeq)
}
}
}
}
}
}
private object PrimitiveType {
def intPrimitive: (Type, Literal, Type) = {
val (d, w) = primitives(definitions.IntClass)
(definitions.IntTpe, d, w)
}
def unapply(tpe: Type): Option[(Literal, Type)] = primitives.get(tpe.typeSymbol)
}
private object BoxedPrimitiveType {
def unapply(tpe: Type): Option[(Literal, Type, Tree => Tree, Tree => Tree)] = boxedPrimitives.get(tpe.typeSymbol)
}
private object ListType {
def unapply(tpe: Type): Option[(Type, Tree => Tree)] = tpe match {
case ArrayType(elemTpe) => {
val iter = { source: Tree =>
Apply(Select(source, "iterator": TermName), Nil)
}
Some(elemTpe, iter)
}
case TraversableType(elemTpe) => {
val iter = { source: Tree => Select(source, "toIterator": TermName) }
Some(elemTpe, iter)
}
case _ => None
}
private object ArrayType {
def unapply(tpe: Type): Option[Type] = tpe match {
case TypeRef(_, _, elemTpe :: Nil) if tpe <:< typeOf[Array[_]] => Some(elemTpe)
case _ => None
}
}
private object TraversableType {
def unapply(tpe: Type): Option[Type] = tpe match {
case _ if tpe <:< typeOf[GenTraversableOnce[_]] => {
// val abstrElemTpe = genTraversableOnceClass.typeConstructor.typeParams.head.tpe
// val elemTpe = abstrElemTpe.asSeenFrom(tpe, genTraversableOnceClass)
// Some(elemTpe)
// TODO: verify that this element-type extraction behaves correctly for all
// GenTraversableOnce subtypes
tpe match {
case TypeRef(_, _, elemTpe :: Nil) => Some(elemTpe.asSeenFrom(tpe, tpe.typeSymbol))
}
}
case _ => None
}
}
}
private object CaseClassType {
def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isCaseClass
}
private object BaseClassType {
def unapply(tpe: Type): Boolean = tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isAbstractClass && tpe.typeSymbol.asClass.isSealed
}
private class UDTAnalyzerCache {
private val caches = new DynamicVariable[Map[Type, RecursiveDescriptor]](Map())
private val idGen = new Counter()
def newId = idGen.next
def getOrElseUpdate(tpe: Type)(orElse: Int => UDTDescriptor): UDTDescriptor = {
val id = idGen.next
val cache = caches.value
cache.get(tpe) map { _.copy(id = id) } getOrElse {
val ref = RecursiveDescriptor(id, tpe, id)
caches.withValue(cache + (tpe -> ref)) {
orElse(id)
}
}
}
}
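// Note: getOrElseUpdate seeds the cache with a RecursiveDescriptor before
// descending into orElse, so a re-entrant analysis of the same type (a cycle)
// resolves to the placeholder instead of recursing forever.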
}
lazy val primitives = Map[Symbol, (Literal, Type)](
definitions.BooleanClass -> (Literal(Constant(false)), typeOf[PactBoolean]),
definitions.ByteClass -> (Literal(Constant(0: Byte)), typeOf[PactByte]),
definitions.CharClass -> (Literal(Constant(0: Char)), typeOf[PactCharacter]),
definitions.DoubleClass -> (Literal(Constant(0: Double)), typeOf[PactDouble]),
definitions.FloatClass -> (Literal(Constant(0: Float)), typeOf[PactFloat]),
definitions.IntClass -> (Literal(Constant(0: Int)), typeOf[PactInteger]),
definitions.LongClass -> (Literal(Constant(0: Long)), typeOf[PactLong]),
definitions.ShortClass -> (Literal(Constant(0: Short)), typeOf[PactShort]),
definitions.StringClass -> (Literal(Constant(null: String)), typeOf[PactString]))
lazy val boxedPrimitives = {
def getBoxInfo(prim: Symbol, primName: String, boxName: String) = {
val (default, wrapper) = primitives(prim)
val box = { t: Tree =>
Apply(Select(Select(Ident(newTermName("scala")), newTermName("Predef")), newTermName(primName + "2" + boxName)), List(t))
}
val unbox = { t: Tree =>
Apply(Select(Select(Ident(newTermName("scala")), newTermName("Predef")), newTermName(boxName + "2" + primName)), List(t))
}
(default, wrapper, box, unbox)
}
Map(
typeOf[java.lang.Boolean].typeSymbol -> getBoxInfo(definitions.BooleanClass, "boolean", "Boolean"),
typeOf[java.lang.Byte].typeSymbol -> getBoxInfo(definitions.ByteClass, "byte", "Byte"),
typeOf[java.lang.Character].typeSymbol -> getBoxInfo(definitions.CharClass, "char", "Character"),
typeOf[java.lang.Double].typeSymbol -> getBoxInfo(definitions.DoubleClass, "double", "Double"),
typeOf[java.lang.Float].typeSymbol -> getBoxInfo(definitions.FloatClass, "float", "Float"),
typeOf[java.lang.Integer].typeSymbol -> getBoxInfo(definitions.IntClass, "int", "Integer"),
typeOf[java.lang.Long].typeSymbol -> getBoxInfo(definitions.LongClass, "long", "Long"),
typeOf[java.lang.Short].typeSymbol -> getBoxInfo(definitions.ShortClass, "short", "Short"))
}
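// Sketch of what getBoxInfo wires up for java.lang.Integer: the generated box
// tree is equivalent to scala.Predef.int2Integer(t) and the unbox tree to
// scala.Predef.Integer2int(t), with default/wrapper taken from the primitives
// map entry for Int (Literal(0), PactInteger).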
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
import scala.reflect.macros.Context
import scala.reflect.classTag
import scala.reflect.ClassTag
import scala.Option.option2Iterable
trait UDTDescriptors[C <: Context] { this: MacroContextHolder[C] =>
import c.universe._
abstract sealed class UDTDescriptor {
val id: Int
val tpe: Type
val isPrimitiveProduct: Boolean = false
def mkRoot: UDTDescriptor = this
def flatten: Seq[UDTDescriptor]
def getters: Seq[FieldAccessor] = Seq()
def select(member: String): Option[UDTDescriptor] = getters find { _.getter.name.toString == member } map { _.desc }
def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
case Nil => Seq(Some(this))
case head :: tail => getters find { _.getter.name.toString == head } match {
case None => Seq(None)
case Some(d : FieldAccessor) => d.desc.select(tail)
}
}
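// Example (hypothetical field names): desc.select(List("address", "zip"))
// walks the getter named "address", then "zip" on its descriptor, yielding
// Seq(Some(zipDesc)) on success and Seq(None) as soon as a segment is missing.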
def findById(id: Int): Option[UDTDescriptor] = flatten.find { _.id == id }
def findByType[T <: UDTDescriptor: ClassTag]: Seq[T] = {
val clazz = classTag[T].runtimeClass
flatten filter { item => clazz.isAssignableFrom(item.getClass) } map { _.asInstanceOf[T] }
}
def getRecursiveRefs: Seq[UDTDescriptor] = (findByType[RecursiveDescriptor] flatMap { rd => findById(rd.refId) } map { _.mkRoot }).distinct
}
case class UnsupportedDescriptor(id: Int, tpe: Type, errors: Seq[String]) extends UDTDescriptor {
override def flatten = Seq(this)
}
case class PrimitiveDescriptor(id: Int, tpe: Type, default: Literal, wrapper: Type) extends UDTDescriptor {
override val isPrimitiveProduct = true
override def flatten = Seq(this)
}
case class BoxedPrimitiveDescriptor(id: Int, tpe: Type, default: Literal, wrapper: Type, box: Tree => Tree, unbox: Tree => Tree) extends UDTDescriptor {
override val isPrimitiveProduct = true
override def flatten = Seq(this)
override def hashCode() = (id, tpe, default, wrapper, "BoxedPrimitiveDescriptor").hashCode()
override def equals(that: Any) = that match {
case BoxedPrimitiveDescriptor(thatId, thatTpe, thatDefault, thatWrapper, _, _) => (id, tpe, default, wrapper).equals((thatId, thatTpe, thatDefault, thatWrapper))
case _ => false
}
}
case class ListDescriptor(id: Int, tpe: Type, iter: Tree => Tree, elem: UDTDescriptor) extends UDTDescriptor {
override def flatten = this +: elem.flatten
def getInnermostElem: UDTDescriptor = elem match {
case list: ListDescriptor => list.getInnermostElem
case _ => elem
}
override def hashCode() = (id, tpe, elem).hashCode()
override def equals(that: Any) = that match {
case that @ ListDescriptor(thatId, thatTpe, _, thatElem) => (id, tpe, elem).equals((thatId, thatTpe, thatElem))
case _ => false
}
}
case class BaseClassDescriptor(id: Int, tpe: Type, override val getters: Seq[FieldAccessor], subTypes: Seq[UDTDescriptor]) extends UDTDescriptor {
override def flatten = this +: ((getters flatMap { _.desc.flatten }) ++ (subTypes flatMap { _.flatten }))
override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
case Nil => getters flatMap { g => g.desc.select(Nil) }
case head :: tail => getters find { _.getter.name.toString == head } match {
case None => Seq(None)
case Some(d : FieldAccessor) => d.desc.select(tail)
}
}
}
case class CaseClassDescriptor(id: Int, tpe: Type, mutable: Boolean, ctor: Symbol, override val getters: Seq[FieldAccessor]) extends UDTDescriptor {
override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct)
override def mkRoot = this.copy(getters = getters map { _.copy(isBaseField = false) })
override def flatten = this +: (getters flatMap { _.desc.flatten })
// Hack: ignore the ctorTpe, since two Type instances representing
// the same ctor function type don't appear to be considered equal.
// Equality of the tpe and ctor fields implies equality of ctorTpe anyway.
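// (scala-reflect Types compare semantically with =:=; the default equals can
// return false for structurally equal but distinct instances, hence this
// workaround of comparing tpe and ctor instead.)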
override def hashCode = (id, tpe, ctor, getters).hashCode
override def equals(that: Any) = that match {
case CaseClassDescriptor(thatId, thatTpe, thatMutable, thatCtor, thatGetters) => (id, tpe, mutable, ctor, getters).equals((thatId, thatTpe, thatMutable, thatCtor, thatGetters))
case _ => false
}
override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
case Nil => getters flatMap { g => g.desc.select(Nil) }
case head :: tail => getters find { _.getter.name.toString == head } match {
case None => Seq(None)
case Some(d : FieldAccessor) => d.desc.select(tail)
}
}
}
case class FieldAccessor(getter: Symbol, setter: Symbol, tpe: Type, isBaseField: Boolean, desc: UDTDescriptor)
case class RecursiveDescriptor(id: Int, tpe: Type, refId: Int) extends UDTDescriptor {
override def flatten = Seq(this)
}
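// A self-referential type such as `case class Node(v: Int, next: Node)`
// (hypothetical example) is cut off at the cycle: the inner occurrence becomes
// a RecursiveDescriptor whose refId points back at the enclosing descriptor's
// id; getRecursiveRefs above resolves such back-references.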
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala.codegen
import scala.reflect.macros.Context
import eu.stratosphere.scala.analysis.UDT
import eu.stratosphere.pact.common.`type`.base.PactList
import eu.stratosphere.pact.common.`type`.PactRecord
trait UDTGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with UDTAnalyzer[C] with TreeGen[C] with SerializerGen[C] with SerializeMethodGen[C] with DeserializeMethodGen[C] with Loggers[C] =>
import c.universe._
def mkUdtClass[T: c.WeakTypeTag](): (ClassDef, Tree) = {
val desc = getUDTDescriptor(weakTypeOf[T])
val udtName = c.fresh[TypeName]("GeneratedUDTDescriptor")
val udt = mkClass(udtName, Flag.FINAL, List(weakTypeOf[UDT[T]]), {
val (ser, createSer) = mkUdtSerializerClass[T](creatorName = "createSerializer")
val ctor = mkMethod(nme.CONSTRUCTOR.toString(), NoFlags, List(), NoType, {
Block(List(mkSuperCall(Nil)), mkUnit)
})
List(ser, createSer, ctor, mkFieldTypes(desc), mkUDTIdToIndexMap(desc))
})
val (_, udtTpe) = typeCheck(udt)
(udt, mkCtorCall(udtTpe, Nil))
}
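// Rough shape of the generated code (a sketch; the actual class name is fresh
// and the serializer body is produced by mkUdtSerializerClass):
//   final class GeneratedUDTDescriptor$1 extends UDT[T] {
//     override val fieldTypes = ...  // see mkFieldTypes below
//     override val udtIdMap = ...    // see mkUDTIdToIndexMap below
//     def createSerializer = new <generated serializer>
//   }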
private def mkFieldTypes(desc: UDTDescriptor): Tree = {
mkVal("fieldTypes", Flag.OVERRIDE | Flag.FINAL, false, typeOf[Array[Class[_ <: eu.stratosphere.pact.common.`type`.Value]]], {
val fieldTypes = getIndexFields(desc).toList map {
case PrimitiveDescriptor(_, _, _, wrapper) => Literal(Constant(wrapper))
case BoxedPrimitiveDescriptor(_, _, _, wrapper, _, _) => Literal(Constant(wrapper))
case ListDescriptor(_, _, _, _) => Literal(Constant(typeOf[PactList[eu.stratosphere.pact.common.`type`.Value]]))
// Box inner instances of recursive types
case RecursiveDescriptor(_, _, _) => Literal(Constant(typeOf[PactRecord]))
case BaseClassDescriptor(_, _, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
case CaseClassDescriptor(_, _, _, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
case UnsupportedDescriptor(_, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
}
Apply(Select(Select(Ident("scala": TermName), "Array": TermName), "apply": TermName), fieldTypes)
})
}
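// For a record whose flat fields are an Int and a String (hypothetical), the
// generated value is equivalent to:
//   override val fieldTypes = Array(classOf[PactInteger], classOf[PactString])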
private def mkUDTIdToIndexMap(desc: UDTDescriptor): Tree = {
mkVal("udtIdMap", Flag.OVERRIDE | Flag.FINAL, false, typeOf[Map[Int, Int]], {
val fieldIds = getIndexFields(desc).toList map {
case PrimitiveDescriptor(id, _, _, _) => Literal(Constant(id))
case BoxedPrimitiveDescriptor(id, _, _, _, _, _) => Literal(Constant(id))
case ListDescriptor(id, _, _, _) => Literal(Constant(id))
// Box inner instances of recursive types
case RecursiveDescriptor(id, _, _) => Literal(Constant(id))
case BaseClassDescriptor(_, _, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
case CaseClassDescriptor(_, _, _, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
case UnsupportedDescriptor(_, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
}
val fields = fieldIds.zipWithIndex map { case (id, idx) =>
val idExpr = c.Expr[Int](id)
val idxExpr = c.Expr[Int](Literal(Constant(idx)))
reify { (idExpr.splice, idxExpr.splice) }.tree
}
Apply(Select(Select(Select(Ident("scala": TermName), "Predef": TermName), "Map": TermName), "apply": TermName), fields)
})
}
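// Illustrative result: Map(0 -> 0, 3 -> 1, 7 -> 2), mapping the analyzer's
// UDT field ids (keys) to positions in the flat PactRecord (values).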
}
/**
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package eu.stratosphere.scala
import scala.collection.TraversableOnce
import eu.stratosphere.scala.analysis._
import eu.stratosphere.pact.generic.contract.Contract
package object operators {
// implicit def funToRepeat[SolutionItem: UDT](stepFunction: DataStream[SolutionItem] => DataStream[SolutionItem]) = new RepeatOperator(stepFunction)
// implicit def funToIterate[SolutionItem, DeltaItem](s0: DataStream[SolutionItem]) = new IterateOperator(s0.contract)
// implicit def funToWorksetIterate[SolutionItem: UDT, WorksetItem: UDT](stepFunction: (DataStream[SolutionItem], DataStream[WorksetItem]) => (DataStream[SolutionItem], DataStream[WorksetItem])) = new WorksetIterateOperator(stepFunction)
implicit def traversableToIterator[T](i: TraversableOnce[T]): Iterator[T] = i.toIterator
implicit def optionToIterator[T](opt: Option[T]): Iterator[T] = opt.iterator
implicit def arrayToIterator[T](arr: Array[T]): Iterator[T] = arr.iterator
}
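// Usage sketch (assumed user code, with `import eu.stratosphere.scala.operators._`
// in scope): these implicits let UDF bodies return collections, options, or
// arrays where an Iterator is expected, e.g.
//   val it1: Iterator[Int] = List(1, 2, 3)   // via traversableToIterator
//   val it2: Iterator[Int] = Some(42)        // via optionToIterator
//   val it3: Iterator[Int] = Array(1, 2, 3)  // via arrayToIterator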