提交 dd4cc579 编写于 作者: A Andrew Scheidecker 提交者: Matt Witherspoon

Support parsing of references to types declared later in a WAST module

上级 9dcc9bb5
......@@ -167,45 +167,60 @@ namespace WAST
return FunctionType::get(ret,parameters);
}
IndexedFunctionType parseFunctionTypeRef(ModuleParseState& state,NameToIndexMap& outLocalNameToIndexMap,std::vector<std::string>& outLocalDisassemblyNames)
UnresolvedFunctionType parseFunctionTypeRefAndOrDecl(ModuleParseState& state,NameToIndexMap& outLocalNameToIndexMap,std::vector<std::string>& outLocalDisassemblyNames)
{
// Parse an optional function type reference.
const Token* typeReferenceToken = state.nextToken;
IndexedFunctionType referencedFunctionType = {UINT32_MAX};
Reference functionTypeRef;
if(state.nextToken[0].type == t_leftParenthesis
&& state.nextToken[1].type == t_type)
{
// Parse a reference by name or index to some type in the module's type table.
parseParenthesized(state,[&]
{
require(state,t_type);
referencedFunctionType.index = parseAndResolveNameOrIndexRef(state,state.typeNameToIndexMap,state.module.types.size(),"type");
if(!tryParseNameOrIndexRef(state,functionTypeRef))
{
parseErrorf(state,state.nextToken,"expected type name or index");
throw RecoverParseException();
}
});
}
// Parse the explicit function parameters and result type.
const FunctionType* directFunctionType = parseFunctionType(state,outLocalNameToIndexMap,outLocalDisassemblyNames);
const bool hasNoDirectType = directFunctionType == FunctionType::get();
const FunctionType* explicitFunctionType = parseFunctionType(state,outLocalNameToIndexMap,outLocalDisassemblyNames);
UnresolvedFunctionType result;
result.reference = functionTypeRef;
result.explicitType = explicitFunctionType;
return result;
}
// Validate that if the function definition has both a type reference, and explicit parameter/result type declarations, that they match.
IndexedFunctionType functionType;
if(referencedFunctionType.index != UINT32_MAX && hasNoDirectType)
IndexedFunctionType resolveFunctionType(ModuleParseState& state,const UnresolvedFunctionType& unresolvedType)
{
if(!unresolvedType.reference)
{
functionType = referencedFunctionType;
return getUniqueFunctionTypeIndex(state,unresolvedType.explicitType);
}
else
{
functionType = getUniqueFunctionTypeIndex(state,directFunctionType);
if(referencedFunctionType.index != UINT32_MAX && state.module.types[referencedFunctionType.index] != directFunctionType)
// Resolve the referenced type.
const U32 referencedFunctionTypeIndex = resolveRef(state,state.typeNameToIndexMap,state.module.types.size(),unresolvedType.reference);
// Validate that if the function definition has both a type reference and explicit parameter/result type declarations, they match.
const bool hasExplicitParametersOrResultType = unresolvedType.explicitType != FunctionType::get();
if(hasExplicitParametersOrResultType)
{
parseErrorf(state,typeReferenceToken,"referenced function type (%s) does not match declared parameters and results (%s)",
asString(state.module.types[referencedFunctionType.index]).c_str(),
asString(directFunctionType).c_str()
);
if(referencedFunctionTypeIndex != UINT32_MAX
&& state.module.types[referencedFunctionTypeIndex] != unresolvedType.explicitType)
{
parseErrorf(state,unresolvedType.reference.token,"referenced function type (%s) does not match declared parameters and results (%s)",
asString(state.module.types[referencedFunctionTypeIndex]).c_str(),
asString(unresolvedType.explicitType).c_str()
);
}
}
}
return functionType;
return {referencedFunctionTypeIndex};
}
}
IndexedFunctionType getUniqueFunctionTypeIndex(ModuleParseState& state,const FunctionType* functionType)
......
......@@ -104,6 +104,13 @@ namespace WAST
operator bool() const { return type != Type::invalid; }
};
// Represents a function type, either as an unresolved name/index, or as an explicit type, or both.
struct UnresolvedFunctionType
{
Reference reference;
const IR::FunctionType* explicitType;
};
// State associated with parsing a module.
struct ModuleParseState : ParseState
{
......@@ -119,6 +126,9 @@ namespace WAST
IR::DisassemblyNames disassemblyNames;
// Thunks that are called after parsing all types.
std::vector<std::function<void(ModuleParseState&)>> postTypeCallbacks;
// Thunks that are called after parsing all declarations.
std::vector<std::function<void(ModuleParseState&)>> postDeclarationCallbacks;
......@@ -141,7 +151,8 @@ namespace WAST
IR::ValueType parseValueType(ParseState& state);
const IR::FunctionType* parseFunctionType(ModuleParseState& state,NameToIndexMap& outLocalNameToIndexMap,std::vector<std::string>& outLocalDisassemblyNames);
IR::IndexedFunctionType parseFunctionTypeRef(ModuleParseState& state,NameToIndexMap& outLocalNameToIndexMap,std::vector<std::string>& outLocalDisassemblyNames);
UnresolvedFunctionType parseFunctionTypeRefAndOrDecl(ModuleParseState& state,NameToIndexMap& outLocalNameToIndexMap,std::vector<std::string>& outLocalDisassemblyNames);
IR::IndexedFunctionType resolveFunctionType(ModuleParseState& state,const UnresolvedFunctionType& unresolvedType);
IR::IndexedFunctionType getUniqueFunctionTypeIndex(ModuleParseState& state,const IR::FunctionType* functionType);
// Literal parsing.
......
......@@ -546,90 +546,79 @@ namespace WAST
std::vector<std::string>* localDisassemblyNames = new std::vector<std::string>;
NameToIndexMap* localNameToIndexMap = new NameToIndexMap();
// Parse an optional function type reference.
const Token* typeReferenceToken = state.nextToken;
IndexedFunctionType referencedFunctionType = {UINT32_MAX};
if(state.nextToken[0].type == t_leftParenthesis
&& state.nextToken[1].type == t_type)
{
referencedFunctionType = parseFunctionTypeRef(state,*localNameToIndexMap,*localDisassemblyNames);
}
// Parse the explicit function parameters and result type.
const FunctionType* directFunctionType = parseFunctionType(state,*localNameToIndexMap,*localDisassemblyNames);
const bool hasNoDirectType = directFunctionType == FunctionType::get();
// Validate that if the function definition has both a type reference, and explicit parameter/result type declarations, that they match.
IndexedFunctionType functionType;
if(referencedFunctionType.index != UINT32_MAX && hasNoDirectType)
{
functionType = referencedFunctionType;
}
else
{
functionType = getUniqueFunctionTypeIndex(state,directFunctionType);
if(referencedFunctionType.index != UINT32_MAX && state.module.types[referencedFunctionType.index] != directFunctionType)
{
parseErrorf(state,typeReferenceToken,"referenced function type (%s) does not match declared parameters and results (%s)",
asString(state.module.types[referencedFunctionType.index]).c_str(),
asString(directFunctionType).c_str()
);
}
}
// Parse the function's local variables.
std::vector<ValueType> nonParameterLocalTypes;
while(tryParseParenthesizedTagged(state,t_local,[&]
{
Name localName;
if(tryParseName(state,localName))
{
bindName(state,*localNameToIndexMap,localName,directFunctionType->parameters.size() + nonParameterLocalTypes.size());
localDisassemblyNames->push_back(localName.getString());
nonParameterLocalTypes.push_back(parseValueType(state));
}
else
{
while(state.nextToken->type != t_rightParenthesis)
{
localDisassemblyNames->push_back(std::string());
nonParameterLocalTypes.push_back(parseValueType(state));
};
}
}));
// Defer parsing the body of the function until after all declarations have been parsed.
// Parse the function type, as a reference or explicit declaration.
const UnresolvedFunctionType unresolvedFunctionType = parseFunctionTypeRefAndOrDecl(state,*localNameToIndexMap,*localDisassemblyNames);
// Defer resolving the function type until all type declarations have been parsed.
const Uptr functionIndex = state.module.functions.size();
const Uptr functionDefIndex = state.module.functions.defs.size();
const Token* firstBodyToken = state.nextToken;
state.postDeclarationCallbacks.push_back([functionIndex,functionDefIndex,firstBodyToken,localNameToIndexMap,localDisassemblyNames](ModuleParseState& state)
state.postTypeCallbacks.push_back(
[functionIndex,functionDefIndex,firstBodyToken,localNameToIndexMap,localDisassemblyNames,unresolvedFunctionType]
(ModuleParseState& state)
{
FunctionParseState functionState(state,localNameToIndexMap,firstBodyToken,state.module.functions.defs[functionDefIndex]);
try
// Resolve the function type and set it on the FunctionDef.
const IndexedFunctionType functionTypeIndex = resolveFunctionType(state,unresolvedFunctionType);
state.module.functions.defs[functionDefIndex].type = functionTypeIndex;
// Defer parsing the body of the function until all function types have been resolved.
state.postDeclarationCallbacks.push_back(
[functionIndex,functionDefIndex,firstBodyToken,localNameToIndexMap,localDisassemblyNames,functionTypeIndex]
(ModuleParseState& state)
{
parseInstrSequence(functionState);
if(!functionState.errors.size())
FunctionDef& functionDef = state.module.functions.defs[functionDefIndex];
const FunctionType* functionType = functionTypeIndex.index == UINT32_MAX
? FunctionType::get()
: state.module.types[functionTypeIndex.index];
// Parse the function's local variables.
ParseState localParseState(state.string,state.lineInfo,state.errors,firstBodyToken);
while(tryParseParenthesizedTagged(localParseState,t_local,[&]
{
functionState.validatingCodeStream.end();
functionState.validatingCodeStream.finishValidation();
}
}
catch(ValidationException exception)
{
parseErrorf(state,firstBodyToken,"%s",exception.message.c_str());
}
catch(RecoverParseException) {}
catch(FatalParseException) {}
Name localName;
if(tryParseName(localParseState,localName))
{
bindName(localParseState,*localNameToIndexMap,localName,functionType->parameters.size() + functionDef.nonParameterLocalTypes.size());
localDisassemblyNames->push_back(localName.getString());
functionDef.nonParameterLocalTypes.push_back(parseValueType(localParseState));
}
else
{
while(localParseState.nextToken->type != t_rightParenthesis)
{
localDisassemblyNames->push_back(std::string());
functionDef.nonParameterLocalTypes.push_back(parseValueType(localParseState));
};
}
}));
state.disassemblyNames.functions[functionIndex].locals = std::move(*localDisassemblyNames);
delete localDisassemblyNames;
state.module.functions.defs[functionDefIndex].code = std::move(functionState.codeByteStream.getBytes());
state.disassemblyNames.functions[functionIndex].locals = std::move(*localDisassemblyNames);
delete localDisassemblyNames;
// Parse the function's code.
FunctionParseState functionState(state,localNameToIndexMap,localParseState.nextToken,functionDef);
try
{
parseInstrSequence(functionState);
if(!functionState.errors.size())
{
functionState.validatingCodeStream.end();
functionState.validatingCodeStream.finishValidation();
}
}
catch(ValidationException exception)
{
parseErrorf(state,firstBodyToken,"%s",exception.message.c_str());
}
catch(RecoverParseException) {}
catch(FatalParseException) {}
functionDef.code = std::move(functionState.codeByteStream.getBytes());
});
});
// Continue parsing after the closing parenthesis.
findClosingParenthesis(state,funcToken-1);
--state.nextToken;
return {functionType,std::move(nonParameterLocalTypes),{}};
return {UINT32_MAX,{},{}};
}
}
\ No newline at end of file
......@@ -111,14 +111,16 @@ static void errorIfFollowsDefinitions(ModuleParseState& state)
}
template<typename Def,typename Type,typename DisassemblyName>
static void createImport(
static Uptr createImport(
ParseState& state,Name name,std::string&& moduleName,std::string&& exportName,
NameToIndexMap& nameToIndexMap,IndexSpace<Def,Type>& indexSpace,std::vector<DisassemblyName>& disassemblyNameArray,
Type type)
{
const Uptr importIndex = indexSpace.imports.size();
bindName(state,nameToIndexMap,name,indexSpace.size());
disassemblyNameArray.push_back({name.getString()});
indexSpace.imports.push_back({type,std::move(moduleName),std::move(exportName)});
return importIndex;
}
static bool parseOptionalSharedDeclaration(ModuleParseState& state)
......@@ -164,11 +166,17 @@ static void parseImport(ModuleParseState& state)
{
NameToIndexMap localNameToIndexMap;
std::vector<std::string> localDissassemblyNames;
const IndexedFunctionType functionType = parseFunctionTypeRef(state,localNameToIndexMap,localDissassemblyNames);
createImport(state,name,std::move(moduleName),std::move(exportName),
const UnresolvedFunctionType unresolvedFunctionType = parseFunctionTypeRefAndOrDecl(state,localNameToIndexMap,localDissassemblyNames);
const Uptr importIndex = createImport(state,name,std::move(moduleName),std::move(exportName),
state.functionNameToIndexMap,state.module.functions,state.disassemblyNames.functions,
functionType);
{UINT32_MAX});
state.disassemblyNames.functions.back().locals = localDissassemblyNames;
// Resolve the function import type after all type declarations have been parsed.
state.postTypeCallbacks.push_back([unresolvedFunctionType,importIndex](ModuleParseState& state)
{
state.module.functions.imports[importIndex].type = resolveFunctionType(state,unresolvedFunctionType);
});
break;
}
case t_table:
......@@ -373,7 +381,6 @@ static void parseElem(ModuleParseState& state)
parseElemSegmentBody(state,tableRef,baseIndex,elemToken);
}
template<typename Def,typename Type,typename ParseImport,typename ParseDef,typename DisassemblyName>
static void parseObjectDefOrImport(
ModuleParseState& state,
......@@ -432,9 +439,18 @@ static void parseFunc(ModuleParseState& state)
parseObjectDefOrImport(state,state.functionNameToIndexMap,state.module.functions,state.disassemblyNames.functions,t_func,ObjectKind::function,
[&](ModuleParseState& state)
{
// Parse the imported function's type.
NameToIndexMap localNameToIndexMap;
std::vector<std::string> localDisassemblyNames;
return parseFunctionTypeRef(state,localNameToIndexMap,localDisassemblyNames);
const UnresolvedFunctionType unresolvedFunctionType = parseFunctionTypeRefAndOrDecl(state,localNameToIndexMap,localDisassemblyNames);
// Resolve the function import type after all type declarations have been parsed.
const Uptr importIndex = state.module.functions.imports.size();
state.postTypeCallbacks.push_back([unresolvedFunctionType,importIndex](ModuleParseState& state)
{
state.module.functions.imports[importIndex].type = resolveFunctionType(state,unresolvedFunctionType);
});
return IndexedFunctionType {UINT32_MAX};
},
parseFunctionDef);
}
......@@ -580,6 +596,15 @@ namespace WAST
parseDeclaration(state);
};
// Process the callbacks requested after all type declarations have been parsed.
if(!state.errors.size())
{
for(const auto& callback : state.postTypeCallbacks)
{
callback(state);
}
}
// Process the callbacks requested after all declarations have been parsed.
if(!state.errors.size())
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册