llvm-project/mlir/lib/Tools/PDLL/Parser/Parser.cpp

3179 lines
118 KiB
C++

//===- Parser.cpp ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "Lexer.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Constraint.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "mlir/Tools/PDLL/ODS/Constraint.h"
#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Parser.h"
#include <string>
using namespace mlir;
using namespace mlir::pdll;
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
namespace {
/// Parser for PDLL source. It pulls tokens from `Lexer`, keeping the current
/// token in `curToken`, and builds an `ast::Module` holding the parsed
/// top-level decls.
class Parser {
public:
/// Construct a parser over `sourceMgr`. The first token is lexed eagerly in
/// the member initializer list, which relies on `lexer` being declared before
/// `curToken` in the field list.
Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
: ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
typeRangeTy(ast::TypeRangeType::get(ctx)),
valueRangeTy(ast::ValueRangeType::get(ctx)),
attrTy(ast::AttributeType::get(ctx)),
codeCompleteContext(codeCompleteContext) {}
/// Try to parse a new module. Returns failure if the module could not be
/// parsed; errors are reported through the lexer's diagnostics.
FailureOr<ast::Module *> parseModule();
private:
/// The current context of the parser. It allows for the parser to know a bit
/// about the construct it is nested within during parsing. This is used
/// specifically to provide additional verification during parsing, e.g. to
/// prevent using rewrites within a match context, matcher constraints within
/// a rewrite section, etc.
enum class ParserContext {
/// The parser is in the global context.
Global,
/// The parser is currently within a Constraint, which disallows all types
/// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
Constraint,
/// The parser is currently within the matcher portion of a Pattern, which
/// allows a terminal operation rewrite statement but no other rewrite
/// transformations.
PatternMatch,
/// The parser is currently within a Rewrite, which disallows calls to
/// constraints, requires operation expressions to have names, etc.
Rewrite,
};
/// The current specification context of an operation's result types. This
/// indicates how the result types of an operation may be inferred.
enum class OpResultTypeContext {
/// The result types of the operation are not known to be inferred.
Explicit,
/// The result types of the operation are inferred from the root input of a
/// `replace` statement.
Replacement,
/// The result types of the operation are inferred by using the
/// `InferTypeOpInterface` interface provided by the operation.
Interface,
};
//===--------------------------------------------------------------------===//
// Parsing
//===--------------------------------------------------------------------===//
/// Allocate a new decl scope whose parent is the current scope, make it the
/// current scope, and return it.
ast::DeclScope *pushDeclScope() {
ast::DeclScope *newScope =
new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
return (curDeclScope = newScope);
}
/// Make an existing `scope` the current decl scope. Unlike the overload above,
/// no new scope is allocated.
void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
/// Pop the current decl scope, making its parent scope current again.
void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
/// Parse the body of an AST module.
LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
/// Try to convert the given expression to `type`. Returns failure and emits
/// an error if a conversion is not viable. On failure, `noteAttachFn` is
/// invoked to attach notes to the emitted error diagnostic. On success,
/// `expr` is updated to the expression used to convert to `type`.
LogicalResult convertExpressionTo(
ast::Expr *&expr, ast::Type type,
function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
/// Conversion helpers used by `convertExpressionTo` for operation-typed and
/// tuple-typed source expressions respectively. `emitErrorFn` emits the
/// generic "unable to convert" diagnostic.
LogicalResult
convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
ast::Type type,
function_ref<ast::InFlightDiagnostic()> emitErrorFn);
LogicalResult convertTupleExpressionTo(
ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
function_ref<ast::InFlightDiagnostic()> emitErrorFn,
function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
/// Given an operation expression, convert it to a ValueRange typed expression
/// that accesses all of the operation's results.
ast::Expr *convertOpToValue(const ast::Expr *opExpr);
/// Lookup ODS information for the given operation, returns nullptr if no
/// information is found.
const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
}
/// Process the given documentation string, or return an empty string if
/// documentation isn't enabled.
StringRef processDoc(StringRef doc) {
return enableDocumentation ? doc : StringRef();
}
/// Process the given documentation string and format it, or return an empty
/// string if documentation isn't enabled.
std::string processAndFormatDoc(const Twine &doc) {
if (!enableDocumentation)
return "";
std::string docStr;
// Scope the stream so that it is destroyed (and flushed into `docStr`)
// before `docStr` is returned.
{
llvm::raw_string_ostream docOS(docStr);
std::string tmpDocStr = doc.str();
raw_indented_ostream(docOS).printReindented(
StringRef(tmpDocStr).rtrim(" \t"));
}
return docStr;
}
//===--------------------------------------------------------------------===//
// Directives
/// Parse a `#...` directive at the current token (currently only `#include`
/// is recognized).
LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
/// Parse an `#include` directive, which may name a `.pdll` or `.td` file.
LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
/// Parse an included tablegen (`.td`) file located at `fileLoc`.
LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
SmallVectorImpl<ast::Decl *> &decls);
/// Process the records of a parsed tablegen include file.
void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
SmallVectorImpl<ast::Decl *> &decls);
/// Create a user defined native constraint for a constraint imported from
/// ODS.
template <typename ConstraintT>
ast::Decl *
createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
SMRange loc, ast::Type type,
StringRef nativeType, StringRef docString);
template <typename ConstraintT>
ast::Decl *
createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
SMRange loc, ast::Type type,
StringRef nativeType);
//===--------------------------------------------------------------------===//
// Decls
/// This structure contains the set of pattern metadata that may be parsed.
struct ParsedPatternMetadata {
Optional<uint16_t> benefit;
bool hasBoundedRecursion = false;
};
FailureOr<ast::Decl *> parseTopLevelDecl();
FailureOr<ast::NamedAttributeDecl *>
parseNamedAttributeDecl(Optional<StringRef> parentOpName);
/// Parse an argument variable as part of the signature of a
/// UserConstraintDecl or UserRewriteDecl.
FailureOr<ast::VariableDecl *> parseArgumentDecl();
/// Parse a result variable as part of the signature of a UserConstraintDecl
/// or UserRewriteDecl.
FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
/// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
/// defined in a non-global context.
FailureOr<ast::UserConstraintDecl *>
parseUserConstraintDecl(bool isInline = false);
/// Parse an inline UserConstraintDecl. An inline decl is one defined in a
/// non-global context, such as within a Pattern/Constraint/etc.
FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
/// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
/// using PDLL constructs.
FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
const ast::Name &name, bool isInline,
ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
/// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
/// defined in a non-global context.
FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
/// Parse an inline UserRewriteDecl. An inline decl is one defined in a
/// non-global context, such as within a Pattern/Rewrite/etc.
FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
/// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
/// PDLL constructs.
FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
const ast::Name &name, bool isInline,
ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
/// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
/// effectively the same syntax, and only differ on slight semantics (given
/// the different parsing contexts).
template <typename T, typename ParseUserPDLLDeclFnT>
FailureOr<T *> parseUserConstraintOrRewriteDecl(
ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
StringRef anonymousNamePrefix, bool isInline);
/// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
/// These decls have effectively the same syntax.
template <typename T>
FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
const ast::Name &name, bool isInline,
ArrayRef<ast::VariableDecl *> arguments,
ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
/// Parse the functional signature (i.e. the arguments and results) of a
/// UserConstraintDecl or UserRewriteDecl.
LogicalResult parseUserConstraintOrRewriteSignature(
SmallVectorImpl<ast::VariableDecl *> &arguments,
SmallVectorImpl<ast::VariableDecl *> &results,
ast::DeclScope *&argumentScope, ast::Type &resultType);
/// Validate the return (which if present is specified by bodyIt) of a
/// UserConstraintDecl or UserRewriteDecl.
LogicalResult validateUserConstraintOrRewriteReturn(
StringRef declType, ast::CompoundStmt *body,
ArrayRef<ast::Stmt *>::iterator bodyIt,
ArrayRef<ast::Stmt *>::iterator bodyE,
ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
FailureOr<ast::CompoundStmt *>
parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
bool expectTerminalSemicolon = true);
FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
FailureOr<ast::Decl *> parsePatternDecl();
LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
/// Check to see if a decl has already been defined with the given name. If
/// one has, emit an error and return failure. Returns success otherwise.
LogicalResult checkDefineNamedDecl(const ast::Name &name);
/// Try to define a variable decl with the given components, returns the
/// variable on success.
FailureOr<ast::VariableDecl *>
defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
ast::Expr *initExpr,
ArrayRef<ast::ConstraintRef> constraints);
FailureOr<ast::VariableDecl *>
defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
ArrayRef<ast::ConstraintRef> constraints);
/// Parse the constraint reference list for a variable decl.
LogicalResult parseVariableDeclConstraintList(
SmallVectorImpl<ast::ConstraintRef> &constraints);
/// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
FailureOr<ast::Expr *> parseTypeConstraintExpr();
/// Try to parse a single reference to a constraint. `typeConstraint` is the
/// location of a previously parsed type constraint for the entity that will
/// be constrained by the parsed constraint. `existingConstraints` are any
/// existing constraints that have already been parsed for the same entity
/// that will be constrained by this constraint. `allowInlineTypeConstraints`
/// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
FailureOr<ast::ConstraintRef>
parseConstraint(Optional<SMRange> &typeConstraint,
ArrayRef<ast::ConstraintRef> existingConstraints,
bool allowInlineTypeConstraints);
/// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
/// argument or result variable. The constraints for these variables do not
/// allow inline type constraints, and only permit a single constraint.
FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
//===--------------------------------------------------------------------===//
// Exprs
FailureOr<ast::Expr *> parseExpr();
/// Identifier expressions.
FailureOr<ast::Expr *> parseAttributeExpr();
FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
FailureOr<ast::Expr *> parseIdentifierExpr();
FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
FailureOr<ast::Expr *>
parseOperationExpr(OpResultTypeContext inputResultTypeContext =
OpResultTypeContext::Explicit);
FailureOr<ast::Expr *> parseTupleExpr();
FailureOr<ast::Expr *> parseTypeExpr();
FailureOr<ast::Expr *> parseUnderscoreExpr();
//===--------------------------------------------------------------------===//
// Stmts
FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
FailureOr<ast::CompoundStmt *> parseCompoundStmt();
FailureOr<ast::EraseStmt *> parseEraseStmt();
FailureOr<ast::LetStmt *> parseLetStmt();
FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
FailureOr<ast::ReturnStmt *> parseReturnStmt();
FailureOr<ast::RewriteStmt *> parseRewriteStmt();
//===--------------------------------------------------------------------===//
// Creation+Analysis
//===--------------------------------------------------------------------===//
//===--------------------------------------------------------------------===//
// Decls
/// Try to extract a callable from the given AST node. Returns nullptr on
/// failure.
ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
/// Try to create a pattern decl with the given components, returning the
/// Pattern on success.
FailureOr<ast::PatternDecl *>
createPatternDecl(SMRange loc, const ast::Name *name,
const ParsedPatternMetadata &metadata,
ast::CompoundStmt *body);
/// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
/// of results, defined as part of the signature.
ast::Type
createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
/// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
template <typename T>
FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
ast::CompoundStmt *body);
/// Try to create a variable decl with the given components, returning the
/// Variable on success.
FailureOr<ast::VariableDecl *>
createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
ArrayRef<ast::ConstraintRef> constraints);
/// Create a variable for an argument or result defined as part of the
/// signature of a UserConstraintDecl/UserRewriteDecl.
FailureOr<ast::VariableDecl *>
createArgOrResultVariableDecl(StringRef name, SMRange loc,
const ast::ConstraintRef &constraint);
/// Validate the constraints used to constrain a variable decl.
/// `inferredType` is the type of the variable inferred by the constraints
/// within the list, and is updated to the most refined type as determined by
/// the constraints. Returns success if the constraint list is valid, failure
/// otherwise.
LogicalResult
validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
ast::Type &inferredType);
/// Validate a single reference to a constraint. `inferredType` contains the
/// currently inferred variable type and is refined within the type defined
/// by the constraint. Returns success if the constraint is valid, failure
/// otherwise.
LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
ast::Type &inferredType);
LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
//===--------------------------------------------------------------------===//
// Exprs
FailureOr<ast::CallExpr *>
createCallExpr(SMRange loc, ast::Expr *parentExpr,
MutableArrayRef<ast::Expr *> arguments);
FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
FailureOr<ast::DeclRefExpr *>
createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
ArrayRef<ast::ConstraintRef> constraints);
FailureOr<ast::MemberAccessExpr *>
createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
/// Validate the member access `name` into the given parent expression. On
/// success, this also returns the type of the member accessed.
FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
StringRef name, SMRange loc);
FailureOr<ast::OperationExpr *>
createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
OpResultTypeContext resultTypeContext,
SmallVectorImpl<ast::Expr *> &operands,
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
SmallVectorImpl<ast::Expr *> &results);
LogicalResult
validateOperationOperands(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
SmallVectorImpl<ast::Expr *> &operands);
LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
SmallVectorImpl<ast::Expr *> &results);
void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
const ods::Operation *odsOp);
LogicalResult validateOperationOperandsOrResults(
StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
Optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
ast::RangeType rangeTy);
FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
ArrayRef<ast::Expr *> elements,
ArrayRef<StringRef> elementNames);
//===--------------------------------------------------------------------===//
// Stmts
FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
FailureOr<ast::ReplaceStmt *>
createReplaceStmt(SMRange loc, ast::Expr *rootOp,
MutableArrayRef<ast::Expr *> replValues);
FailureOr<ast::RewriteStmt *>
createRewriteStmt(SMRange loc, ast::Expr *rootOp,
ast::CompoundStmt *rewriteBody);
//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//
/// The set of various code completion methods. Every completion method
/// returns `failure` to stop the parsing process after providing completion
/// results.
LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
LogicalResult codeCompleteConstraintName(ast::Type inferredType,
bool allowInlineTypeConstraints);
LogicalResult codeCompleteDialectName();
LogicalResult codeCompleteOperationName(StringRef dialectName);
LogicalResult codeCompletePatternMetadata();
LogicalResult codeCompleteIncludeFilename(StringRef curPath);
void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
void codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
unsigned currentNumOperands);
void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
unsigned currentNumResults);
//===--------------------------------------------------------------------===//
// Lexer Utilities
//===--------------------------------------------------------------------===//
/// If the current token has the specified kind, consume it and return true.
/// If not, return false.
bool consumeIf(Token::Kind kind) {
if (curToken.isNot(kind))
return false;
consumeToken(kind);
return true;
}
/// Advance the current lexer onto the next token.
void consumeToken() {
assert(curToken.isNot(Token::eof, Token::error) &&
"shouldn't advance past EOF or errors");
curToken = lexer.lexToken();
}
/// Advance the current lexer onto the next token, asserting what the expected
/// current token is. This is preferred to the above method because it leads
/// to more self-documenting code with better checking.
void consumeToken(Token::Kind kind) {
assert(curToken.is(kind) && "consumed an unexpected token");
consumeToken();
}
/// Reset the lexer to the start of `tokLoc` and re-lex the current token from
/// there.
void resetToken(SMRange tokLoc) {
lexer.resetPointer(tokLoc.Start.getPointer());
curToken = lexer.lexToken();
}
/// Consume the specified token if present and return success. On failure,
/// output a diagnostic and return failure.
LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
if (curToken.getKind() != kind)
return emitError(curToken.getLoc(), msg);
consumeToken();
return success();
}
/// Emit an error at `loc` through the lexer and return failure, so callers
/// can simply `return emitError(...)`.
LogicalResult emitError(SMRange loc, const Twine &msg) {
lexer.emitError(loc, msg);
return failure();
}
/// Emit an error at the location of the current token and return failure.
LogicalResult emitError(const Twine &msg) {
return emitError(curToken.getLoc(), msg);
}
/// Emit an error at `loc` with an attached note at `noteLoc`, then return
/// failure.
LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
const Twine &note) {
lexer.emitErrorAndNote(loc, msg, noteLoc, note);
return failure();
}
//===--------------------------------------------------------------------===//
// Fields
//===--------------------------------------------------------------------===//
/// The owning AST context.
ast::Context &ctx;
/// The lexer of this parser. NOTE: must stay declared before `curToken`,
/// because the constructor initializes `curToken` via `lexer.lexToken()`.
Lexer lexer;
/// The current token within the lexer.
Token curToken;
/// A flag indicating if the parser should add documentation to AST nodes when
/// viable.
bool enableDocumentation;
/// The most recently defined decl scope.
ast::DeclScope *curDeclScope = nullptr;
/// Allocator providing the storage for scopes created by `pushDeclScope()`.
llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
/// The current context of the parser.
ParserContext parserContext = ParserContext::Global;
/// Cached types to simplify verification and expression creation.
ast::Type typeTy, valueTy;
ast::RangeType typeRangeTy, valueRangeTy;
ast::Type attrTy;
/// A counter used when naming anonymous constraints and rewrites.
unsigned anonymousDeclNameCounter = 0;
/// The optional code completion context.
CodeCompleteContext *codeCompleteContext;
};
} // namespace
/// Parse the entire input as a single module. All top-level decls are parsed
/// inside one module-wide decl scope, which is popped whether or not parsing
/// succeeded.
FailureOr<ast::Module *> Parser::parseModule() {
  // The module is anchored at the first token of the input.
  SMLoc startLoc = curToken.getStartLoc();

  pushDeclScope();
  SmallVector<ast::Decl *> topLevelDecls;
  LogicalResult bodyStatus = parseModuleBody(topLevelDecls);
  popDeclScope();

  if (failed(bodyStatus))
    return failure();
  return ast::Module::create(ctx, startLoc, topLevelDecls);
}
/// Parse directives and top-level decls until the end of the current input,
/// appending every parsed decl to `decls`. Stops at the first failure.
LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
  while (!curToken.is(Token::eof)) {
    if (curToken.is(Token::directive)) {
      // Directives (e.g. includes) append their own decls to `decls`.
      if (failed(parseDirective(decls)))
        return failure();
    } else {
      FailureOr<ast::Decl *> topDecl = parseTopLevelDecl();
      if (failed(topDecl))
        return failure();
      decls.push_back(*topDecl);
    }
  }
  return success();
}
/// Wrap `opExpr` in an access of all of its results, producing an expression
/// of the ValueRange type located at the same place as `opExpr`.
ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
  SMRange opLoc = opExpr->getLoc();
  return ast::AllResultsMemberAccessExpr::create(ctx, opLoc, opExpr,
                                                 valueRangeTy);
}
/// Convert `expr` to `type`, rewriting `expr` in place when a conversion
/// expression is required. Emits an "unable to convert" error (with any notes
/// added by `noteAttachFn`) and returns failure when no conversion applies.
LogicalResult Parser::convertExpressionTo(
    ast::Expr *&expr, ast::Type type,
    function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
  ast::Type srcType = expr->getType();

  // Nothing to do when the types already match.
  if (srcType == type)
    return success();

  // Builds the conversion error on demand; only the failure paths call it.
  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
    ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
        expr->getLoc(), llvm::formatv("unable to convert expression of type "
                                      "`{0}` to the expected type of "
                                      "`{1}`",
                                      srcType, type));
    if (noteAttachFn)
      noteAttachFn(*diag);
    return diag;
  };

  // Operation-typed sources have their own conversion rules.
  if (auto srcOpType = srcType.dyn_cast<ast::OperationType>())
    return convertOpExpressionTo(expr, srcOpType, type, emitConvertError);

  // FIXME: Decide how to allow/support converting a single result to multiple,
  // and multiple to a single result. For now, we just allow Single->Range,
  // but this isn't something really supported in the PDL dialect. We should
  // figure out some way to support both.
  auto isValueLike = [&](ast::Type t) {
    return t == valueTy || t == valueRangeTy;
  };
  auto isTypeLike = [&](ast::Type t) {
    return t == typeTy || t == typeRangeTy;
  };
  bool bothValueLike = isValueLike(srcType) && isValueLike(type);
  bool bothTypeLike = isTypeLike(srcType) && isTypeLike(type);
  if (bothValueLike || bothTypeLike)
    return success();

  // Tuple-typed sources may convert element-wise or into a range.
  if (auto srcTupleType = srcType.dyn_cast<ast::TupleType>())
    return convertTupleExpressionTo(expr, srcTupleType, type, emitConvertError,
                                    noteAttachFn);

  return emitConvertError();
}
/// Convert an operation-typed expression `expr` (of type `exprType`) to
/// `type`. Supported targets are:
///   * an unnamed operation type (the most general `Op`),
///   * ValueRange (all results of the operation),
///   * Value (the operation's results, constrained to a single value).
/// On failure, `emitErrorFn` is used to report the error (possibly with an
/// attached note pointing at the ODS definition) and failure is returned.
LogicalResult Parser::convertOpExpressionTo(
    ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
    function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
  // Two operation types are compatible if they have the same name, or if the
  // expected type is more general. When called from `convertExpressionTo`,
  // identical types were already accepted before reaching here, so only an
  // unnamed (fully general) operation target is compatible at this point.
  if (auto opType = type.dyn_cast<ast::OperationType>()) {
    if (opType.getName())
      return emitErrorFn();
    return success();
  }

  // An operation can always convert to a ValueRange. Reuse the shared helper
  // rather than duplicating its construction here.
  if (type == valueRangeTy) {
    expr = convertOpToValue(expr);
    return success();
  }

  // Allow conversion to a single value by constraining the result range.
  if (type == valueTy) {
    // If the operation is registered, we can verify if it can ever have a
    // single result.
    if (const ods::Operation *odsOp = exprType.getODSOperation()) {
      if (odsOp->getResults().empty()) {
        return emitErrorFn()->attachNote(
            llvm::formatv("see the definition of `{0}`, which was defined "
                          "with zero results",
                          odsOp->getName()),
            odsOp->getLoc());
      }

      // More than one non-variadic, non-optional result means the operation
      // can never produce exactly one value.
      unsigned numSingleResults = llvm::count_if(
          odsOp->getResults(), [](const ods::OperandOrResult &result) {
            return result.getVariableLengthKind() ==
                   ods::VariableLengthKind::Single;
          });
      if (numSingleResults > 1) {
        return emitErrorFn()->attachNote(
            llvm::formatv("see the definition of `{0}`, which was defined "
                          "with at least {1} results",
                          odsOp->getName(), numSingleResults),
            odsOp->getLoc());
      }
    }

    expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                                   valueTy);
    return success();
  }
  return emitErrorFn();
}
/// Convert a tuple-typed expression `expr` (of type `exprType`) to `type`.
/// Two kinds of target are supported:
///   * another tuple of the same arity, converting each element in turn;
///   * a ValueRange/TypeRange, when every element is a compatible Value(Range)
///     or Type(Range). This is currently only allowed in a rewrite context.
/// On success `expr` is replaced with the converted expression.
LogicalResult Parser::convertTupleExpressionTo(
    ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
    function_ref<ast::InFlightDiagnostic()> emitErrorFn,
    function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
  // Builds `expr.<idx>`, an access of the idx-th element of the source tuple.
  auto accessElement = [&](unsigned idx) -> ast::Expr * {
    return ast::MemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                         llvm::to_string(idx),
                                         exprType.getElementTypes()[idx]);
  };

  // Tuple -> tuple: convert element-wise into a freshly built tuple.
  if (auto dstTupleType = type.dyn_cast<ast::TupleType>()) {
    if (dstTupleType.size() != exprType.size())
      return emitErrorFn();

    SmallVector<ast::Expr *> convertedElements;
    for (unsigned idx = 0, end = exprType.size(); idx != end; ++idx) {
      convertedElements.push_back(accessElement(idx));

      // Point at the failing element before any caller-provided notes.
      auto attachElementNote = [&](ast::Diagnostic &diag) {
        diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
                                      idx, exprType));
        if (noteAttachFn)
          noteAttachFn(diag);
      };
      if (failed(convertExpressionTo(convertedElements.back(),
                                     dstTupleType.getElementTypes()[idx],
                                     attachElementNote)))
        return failure();
    }
    expr = ast::TupleExpr::create(ctx, expr->getLoc(), convertedElements,
                                  dstTupleType.getElementNames());
    return success();
  }

  // Tuple -> range: gather every element into a single range expression.
  auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
                            ast::RangeType resultTy) -> LogicalResult {
    // TODO: We currently only allow range conversion within a rewrite context.
    if (parserContext != ParserContext::Rewrite) {
      return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
                                       "only allowed within a rewrite context");
    }

    bool allElementsAllowed =
        llvm::all_of(exprType.getElementTypes(), [&](ast::Type elementType) {
          return llvm::is_contained(allowedElementTypes, elementType);
        });
    if (!allElementsAllowed)
      return emitErrorFn();

    SmallVector<ast::Expr *> rangeElements;
    for (unsigned idx = 0, end = exprType.size(); idx != end; ++idx)
      rangeElements.push_back(accessElement(idx));
    expr = ast::RangeExpr::create(ctx, expr->getLoc(), rangeElements, resultTy);
    return success();
  };

  if (type == valueRangeTy)
    return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
  if (type == typeRangeTy)
    return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
  return emitErrorFn();
}
//===----------------------------------------------------------------------===//
// Directives
/// Dispatch on the spelling of the current directive token. `#include` is the
/// only directive recognized here; anything else is reported as an error.
LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
  StringRef directiveName = curToken.getSpelling();
  if (directiveName != "#include")
    return emitError("unknown directive `" + directiveName + "`");
  return parseInclude(decls);
}
/// Parse an `#include "<file>"` directive. `.pdll` files are lexed and parsed
/// into the current module. `.td` files are handed to the tablegen include
/// path. Any other extension is rejected.
LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
  SMRange directiveLoc = curToken.getLoc();
  consumeToken(Token::directive);

  // A code-completion token in place of the path requests filename completion.
  if (curToken.is(Token::code_complete_string))
    return codeCompleteIncludeFilename(curToken.getStringValue());

  if (!curToken.isString())
    return emitError(directiveLoc,
                     "expected string file name after `include` directive");

  // Copy the path out of the token before advancing past it.
  SMRange fileLoc = curToken.getLoc();
  std::string filenameStr = curToken.getStringValue();
  StringRef filename = filenameStr;
  consumeToken();

  // Anything that is not a nested PDLL file is either tablegen or an error.
  if (!filename.endswith(".pdll")) {
    if (filename.endswith(".td"))
      return parseTdInclude(filename, fileLoc, decls);
    return emitError(fileLoc,
                     "expected include filename to end with `.pdll` or `.td`");
  }

  if (failed(lexer.pushInclude(filename, fileLoc)))
    return emitError(fileLoc, "unable to open include file `" + filename + "`");

  // Lex the first token of the nested file, parse its decls into this module,
  // then advance to the next token once the nested file is exhausted.
  curToken = lexer.lexToken();
  LogicalResult nestedStatus = parseModuleBody(decls);
  curToken = lexer.lexToken();
  return nestedStatus;
}
/// Parse the tablegen file `filename` (located in the PDLL source at
/// `fileLoc`) and import the ODS information it contains. Tablegen diagnostics
/// are re-reported as errors at `fileLoc`. On success, the tablegen source
/// buffers are transferred into the parser's source manager.
LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();

  // Use the parser's source manager to locate and open the file, but don't
  // add it as a buffer yet.
  std::string includedFile;
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
      parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
  if (!includeBuffer)
    return emitError(fileLoc, "unable to open include file `" + filename + "`");

  // Setup the source manager for parsing the tablegen file.
  llvm::SourceMgr tdSrcMgr;
  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());

  // This class provides a context argument for the llvm::SourceMgr diagnostic
  // handler.
  struct DiagHandlerContext {
    Parser &parser;
    StringRef filename;
    llvm::SMRange loc;
  } handlerContext{*this, filename, fileLoc};

  // Set the diagnostic handler for the tablegen source manager.
  // NOTE(review): every tablegen diagnostic, whatever its kind, is reported
  // here as an error; confirm that is intended for warnings/notes.
  tdSrcMgr.setDiagHandler(
      [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
        // The opaque pointer is the `handlerContext` registered below, so a
        // static_cast back from `void *` is the correct (and sufficient) cast.
        auto *handlerCtx = static_cast<DiagHandlerContext *>(rawHandlerContext);
        (void)handlerCtx->parser.emitError(
            handlerCtx->loc,
            llvm::formatv("error while processing include file `{0}`: {1}",
                          handlerCtx->filename, diag.getMessage()));
      },
      &handlerContext);

  // Parse the tablegen file.
  llvm::RecordKeeper tdRecords;
  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
    return failure();

  // Process the parsed records.
  processTdIncludeRecords(tdRecords, decls);

  // After we are done processing, move all of the tablegen source buffers to
  // the main parser source mgr. This allows for directly using source locations
  // from the .td files without needing to remap them.
  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
  return success();
}
/// Import the records of a parsed `.td` file.
///
/// `Op` records are registered as operations in the ODS context; they do not
/// produce PDLL decls. `Attr`, `Type`, and `OpInterface` records become native
/// PDLL constraint decls. Those decls are appended to `decls` and, via
/// createODSNativePDLLConstraintDecl, added to the current decl scope.
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
SmallVectorImpl<ast::Decl *> &decls) {
// Return the length kind of the given value. Optional takes precedence over
// variadic; anything else is a single value.
auto getLengthKind = [](const auto &value) {
if (value.isOptional())
return ods::VariableLengthKind::Optional;
return value.isVariadic() ? ods::VariableLengthKind::Variadic
: ods::VariableLengthKind::Single;
};
// Insert a type constraint into the ODS context.
ods::Context &odsContext = ctx.getODSContext();
auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
-> const ods::TypeConstraint & {
return odsContext.insertTypeConstraint(
cst.constraint.getUniqueDefName(),
processDoc(cst.constraint.getSummary()),
cst.constraint.getCPPClassName());
};
// TableGen records only carry a start location. Build a one-character range
// beginning at that location for use as the decl's source range.
auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
};
// Process the parsed tablegen records to build ODS information.
/// Operations.
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
tblgen::Operator op(def);
// Check to see if this operation is known to support result type inference
// (i.e. it has the InferTypeOpInterface trait).
bool supportsResultTypeInferrence =
op.getTrait("::mlir::InferTypeOpInterface::Trait");
auto [odsOp, inserted] = odsContext.insertOperation(
op.getOperationName(), processDoc(op.getSummary()),
processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
supportsResultTypeInferrence, op.getLoc().front());
// Ignore operations that have already been added.
if (!inserted)
continue;
// Record the attributes, operands, and results of the new operation.
for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
odsContext.insertAttributeConstraint(
attr.attr.getUniqueDefName(),
processDoc(attr.attr.getSummary()),
attr.attr.getStorageType()));
}
for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
odsOp->appendOperand(operand.name, getLengthKind(operand),
addTypeConstraint(operand));
}
for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
odsOp->appendResult(result.name, getLengthKind(result),
addTypeConstraint(result));
}
}
// A record does not get a constraint decl if any of these hold:
// - it is anonymous;
// - a decl with the same name is already visible in the current scope;
// - it is a `DeclareInterfaceMethods` subclass.
auto shouldBeSkipped = [this](llvm::Record *def) {
return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
def->isSubClassOf("DeclareInterfaceMethods");
};
/// Attr constraints.
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
if (shouldBeSkipped(def))
continue;
tblgen::Attribute constraint(def);
decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
constraint, convertLocToRange(def->getLoc().front()), attrTy,
constraint.getStorageType()));
}
/// Type constraints.
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
if (shouldBeSkipped(def))
continue;
tblgen::TypeConstraint constraint(def);
decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
constraint, convertLocToRange(def->getLoc().front()), typeTy,
constraint.getCPPClassName()));
}
/// OpInterfaces.
// Each interface becomes an Op constraint. Its native body checks
// `llvm::isa<Namespace::Interface>(self)`.
ast::Type opTy = ast::OperationType::get(ctx);
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
if (shouldBeSkipped(def))
continue;
SMRange loc = convertLocToRange(def->getLoc().front());
std::string cppClassName =
llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
def->getValueAsString("cppInterfaceName"))
.str();
std::string codeBlock =
llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
cppClassName)
.str();
std::string desc =
processAndFormatDoc(def->getValueAsString("description"));
decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
def->getName(), codeBlock, loc, opTy, cppClassName, desc));
}
}
template <typename ConstraintT>
/// Create a native PDLL constraint decl named `name` that takes a single
/// `self` argument.
///
/// The argument has type `type` and is constrained by `ConstraintT`. The
/// constraint's native body is `codeBlock`. The new decl is registered in the
/// current decl scope and returned.
ast::Decl *Parser::createODSNativePDLLConstraintDecl(
    StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
    StringRef nativeType, StringRef docString) {
  // The `self` argument lives in its own short-lived scope, so it never
  // becomes visible in the enclosing scope.
  ast::DeclScope *selfScope = pushDeclScope();
  const ast::Name &selfName = ast::Name::create(ctx, "self", loc);
  ast::ConstraintRef selfConstraint(ConstraintT::create(ctx, loc));
  ast::VariableDecl *selfVar = ast::VariableDecl::create(
      ctx, selfName, type, /*initExpr=*/nullptr, selfConstraint);
  selfScope->add(selfVar);
  popDeclScope();

  // The native constraint itself returns nothing (an empty tuple).
  const ast::Name &declName = ast::Name::create(ctx, name, loc);
  ast::UserConstraintDecl *nativeDecl = ast::UserConstraintDecl::createNative(
      ctx, declName, selfVar, /*results=*/std::nullopt, codeBlock,
      ast::TupleType::get(ctx), nativeType);
  nativeDecl->setDocComment(ctx, docString);
  curDeclScope->add(nativeDecl);
  return nativeDecl;
}
template <typename ConstraintT>
/// Create a native PDLL constraint decl from an ODS/TableGen constraint.
///
/// The constraint's condition template is formatted with the "self"
/// placeholder bound to the generated `self` argument. The result is wrapped
/// as `return ::mlir::success(<condition>);`. When documentation is enabled,
/// the decl's doc comment is the constraint summary, followed by its
/// description when one is present.
ast::Decl *
Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
                                          SMRange loc, ast::Type type,
                                          StringRef nativeType) {
  // Format the condition template.
  tblgen::FmtContext fmtContext;
  fmtContext.withSelf("self");
  std::string codeBlock = tblgen::tgfmt(
      "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
      &fmtContext);

  // If documentation was enabled, build the doc string for the generated
  // constraint. It would be nice to do this lazily, but TableGen information is
  // destroyed after we finish parsing the file.
  std::string docString;
  if (enableDocumentation) {
    // Fetch the description once and reuse it, rather than querying the
    // constraint a second time when building the doc string.
    StringRef desc = constraint.getDescription();
    docString = processAndFormatDoc(constraint.getSummary() +
                                    (desc.empty() ? "" : ("\n\n" + desc)));
  }

  return createODSNativePDLLConstraintDecl<ConstraintT>(
      constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
      docString);
}
//===----------------------------------------------------------------------===//
// Decls
//===----------------------------------------------------------------------===//
/// Parse one top-level declaration: a `Constraint`, a `Pattern`, or a
/// `Rewrite`.
///
/// If the parsed decl is named, the name is checked for redefinition and the
/// decl is added to the current scope.
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
  FailureOr<ast::Decl *> parsedDecl;
  if (curToken.is(Token::kw_Constraint))
    parsedDecl = parseUserConstraintDecl();
  else if (curToken.is(Token::kw_Pattern))
    parsedDecl = parsePatternDecl();
  else if (curToken.is(Token::kw_Rewrite))
    parsedDecl = parseUserRewriteDecl();
  else
    return emitError("expected top-level declaration, such as a `Pattern`");

  if (failed(parsedDecl))
    return failure();

  ast::Decl *decl = *parsedDecl;
  const ast::Name *declName = decl->getName();
  // Unnamed decls (e.g. anonymous patterns) are not registered in any scope.
  if (!declName)
    return decl;
  if (failed(checkDefineNamedDecl(*declName)))
    return failure();
  curDeclScope->add(decl);
  return decl;
}
/// Parse a named attribute of the form `name` or `name = <expr>`.
///
/// The name may be a string literal, an identifier, or a keyword. When no
/// value is given, the attribute's value is a `unit` attribute expression.
/// `parentOpName` is only used to drive code completion of attribute names.
FailureOr<ast::NamedAttributeDecl *>
Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
  if (curToken.is(Token::code_complete))
    return codeCompleteAttributeName(parentOpName);

  std::string nameStr;
  if (curToken.isString()) {
    nameStr = curToken.getStringValue();
  } else if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    return emitError("expected identifier or string attribute name");
  } else {
    nameStr = curToken.getSpelling().str();
  }
  const ast::Name &attrName =
      ast::Name::create(ctx, nameStr, curToken.getLoc());
  consumeToken();

  // An attribute without `= <expr>` is treated as a UnitAttr.
  ast::Expr *valueExpr = nullptr;
  if (!consumeIf(Token::equal)) {
    valueExpr = ast::AttributeExpr::create(ctx, attrName.getLoc(), "unit");
  } else {
    FailureOr<ast::Expr *> parsedValue = parseExpr();
    if (failed(parsedValue))
      return failure();
    valueExpr = *parsedValue;
  }
  return ast::NamedAttributeDecl::create(ctx, attrName, valueExpr);
}
/// Parse a lambda body `=> <stmt>` into a single-statement compound statement.
///
/// `processStatementFn` may validate or replace the parsed statement in place.
/// The statement is parsed in a fresh decl scope, which is popped whether or
/// not parsing and processing succeed.
FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
    function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
    bool expectTerminalSemicolon) {
  consumeToken(Token::equal_arrow);

  SMLoc startLoc = curToken.getStartLoc();
  pushDeclScope();
  FailureOr<ast::Stmt *> stmt = parseStmt(expectTerminalSemicolon);
  // Only run the processing callback on a successfully parsed statement.
  LogicalResult status = failed(stmt) ? failure() : processStatementFn(*stmt);
  popDeclScope();
  if (failed(status))
    return failure();

  return ast::CompoundStmt::create(
      ctx, SMRange(startLoc, curToken.getStartLoc()), *stmt);
}
/// Parse a named argument of the form `name : <constraint>`.
///
/// The name must be an identifier, or a token for which
/// `isDependentKeyword()` holds.
FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
  bool hasName =
      curToken.is(Token::identifier) || curToken.isDependentKeyword();
  if (!hasName)
    return emitError("expected identifier argument name");

  StringRef argName = curToken.getSpelling();
  SMRange argNameLoc = curToken.getLoc();
  consumeToken();

  if (failed(
          parseToken(Token::colon, "expected `:` before argument constraint")))
    return failure();

  FailureOr<ast::ConstraintRef> argConstraint = parseArgOrResultConstraint();
  if (failed(argConstraint))
    return failure();
  return createArgOrResultVariableDecl(argName, argNameLoc, *argConstraint);
}
/// Parse a single result of a Constraint or Rewrite signature.
///
/// A result is named (`name : <constraint>`) when it starts with an
/// identifier-like token that does not resolve to a known Constraint.
/// Otherwise the result is unnamed and the token begins its constraint.
/// `resultNum` is not used by this implementation.
FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
  bool isNamed =
      (curToken.is(Token::identifier) || curToken.isDependentKeyword()) &&
      !curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling());

  if (!isNamed) {
    // Unnamed result: parse the constraint directly. The variable is anchored
    // at the constraint reference.
    FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
    if (failed(cst))
      return failure();
    return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
  }

  // Named result. Redefinition of an existing name is diagnosed later, when the
  // variable is defined.
  StringRef resultName = curToken.getSpelling();
  SMRange resultNameLoc = curToken.getLoc();
  consumeToken();
  if (failed(
          parseToken(Token::colon, "expected `:` before result constraint")))
    return failure();
  FailureOr<ast::ConstraintRef> resultConstraint = parseArgOrResultConstraint();
  if (failed(resultConstraint))
    return failure();
  return createArgOrResultVariableDecl(resultName, resultNameLoc,
                                       *resultConstraint);
}
/// Parse a `Constraint` declaration. When `isInline` is true, the
/// declaration may be unnamed.
FailureOr<ast::UserConstraintDecl *>
Parser::parseUserConstraintDecl(bool isInline) {
  // Constraints share their declaration grammar with rewrites. Only the
  // PDLL-bodied form needs a constraint-specific callback.
  auto parsePDLLBody = [&](auto &&...args) {
    return this->parseUserPDLLConstraintDecl(args...);
  };
  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
      parsePDLLBody, ParserContext::Constraint, "constraint", isInline);
}
/// Parse an inline `Constraint` declaration and register it in the current
/// scope.
FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
  FailureOr<ast::UserConstraintDecl *> inlineDecl =
      parseUserConstraintDecl(/*isInline=*/true);
  if (failed(inlineDecl))
    return failure();

  // Inline decls always carry a name, possibly a generated anonymous one, so
  // the name is always checked and registered.
  ast::UserConstraintDecl *decl = *inlineDecl;
  if (failed(checkDefineNamedDecl(decl->getName())))
    return failure();
  curDeclScope->add(decl);
  return decl;
}
/// Parse the body of a PDLL `Constraint` decl and create the decl.
///
/// The body takes one of two forms:
/// - a compound block (`{ ... }`), in which any `return` must be the last
///   statement;
/// - a lambda (`=> <expr>`), whose single expression is wrapped into a
///   `return`.
FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // Re-enter the argument scope so that the body can refer to the arguments.
  pushDeclScope(argumentScope);

  ast::CompoundStmt *body = nullptr;
  if (curToken.isNot(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> compoundBody = parseCompoundStmt();
    if (failed(compoundBody))
      return failure();
    body = *compoundBody;

    // Find the first `return` (if any), then check that it is well placed.
    auto stmtIt = body->begin();
    auto stmtEnd = body->end();
    while (stmtIt != stmtEnd && !isa<ast::ReturnStmt>(*stmtIt))
      ++stmtIt;
    if (failed(validateUserConstraintOrRewriteReturn(
            "Constraint", body, stmtIt, stmtEnd, results, resultType)))
      return failure();
  } else {
    // In lambda form, the single statement must be an expression. It is
    // rewritten into a `return` of that expression.
    auto wrapInReturn = [&](ast::Stmt *&stmt) -> LogicalResult {
      auto *expr = dyn_cast<ast::Expr>(stmt);
      if (!expr) {
        return emitError(stmt->getLoc(),
                         "expected `Constraint` lambda body to contain a "
                         "single expression");
      }
      stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), expr);
      return success();
    };
    FailureOr<ast::CompoundStmt *> lambdaBody =
        parseLambdaBody(wrapInReturn, /*expectTerminalSemicolon=*/!isInline);
    if (failed(lambdaBody))
      return failure();
    body = *lambdaBody;
  }
  popDeclScope();

  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
      name, arguments, results, resultType, body);
}
/// Parse a `Rewrite` declaration. When `isInline` is true, the declaration
/// may be unnamed.
FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
  // Rewrites share their declaration grammar with constraints. Only the
  // PDLL-bodied form needs a rewrite-specific callback.
  auto parsePDLLBody = [&](auto &&...args) {
    return this->parseUserPDLLRewriteDecl(args...);
  };
  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
      parsePDLLBody, ParserContext::Rewrite, "rewrite", isInline);
}
/// Parse an inline `Rewrite` declaration and register it in the current
/// scope.
FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
  FailureOr<ast::UserRewriteDecl *> inlineDecl =
      parseUserRewriteDecl(/*isInline=*/true);
  if (failed(inlineDecl))
    return failure();

  // Inline decls always carry a name, possibly a generated anonymous one, so
  // the name is always checked and registered.
  ast::UserRewriteDecl *decl = *inlineDecl;
  if (failed(checkDefineNamedDecl(decl->getName())))
    return failure();
  curDeclScope->add(decl);
  return decl;
}
/// Parse the body of a PDLL `Rewrite` decl and create the decl.
///
/// The body takes one of two forms:
/// - a compound block (`{ ... }`);
/// - a lambda (`=> <stmt>`). In lambda form, an operation rewrite statement is
///   kept as-is, while an expression is wrapped into a `return`.
///
/// In both forms, any `return` must be the last statement. A missing `return`
/// is an error when the rewrite declares results.
FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // Push the argument scope back onto the list, so that the body can
  // reference arguments. Go through the scope helper, as the `Constraint`
  // variant does, instead of assigning `curDeclScope` directly.
  pushDeclScope(argumentScope);
  ast::CompoundStmt *body;
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
        [&](ast::Stmt *&statement) -> LogicalResult {
          if (isa<ast::OpRewriteStmt>(statement))
            return success();

          ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
          if (!statementExpr) {
            return emitError(
                statement->getLoc(),
                "expected `Rewrite` lambda body to contain a single expression "
                "or an operation rewrite statement; such as `erase`, "
                "`replace`, or `rewrite`");
          }
          statement =
              ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
          return success();
        },
        /*expectTerminalSemicolon=*/!isInline);
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  }
  popDeclScope();

  // Verify the structure of the body. This also applies to lambda bodies: an
  // operation rewrite statement is not a `return`, so it is only valid when
  // the rewrite has no results.
  auto bodyIt = body->begin(), bodyE = body->end();
  for (; bodyIt != bodyE; ++bodyIt)
    if (isa<ast::ReturnStmt>(*bodyIt))
      break;
  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
                                                   bodyE, results, resultType)))
    return failure();
  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
      name, arguments, results, resultType, body);
}
template <typename T, typename ParseUserPDLLDeclFnT>
/// Parse the shared `Constraint`/`Rewrite` declaration grammar.
///
/// The grammar is the keyword, an optional name (only inline decls may omit
/// it), the signature, and then a body. A body introduced by `{` or `=>` is
/// handed to `parseUserPDLLFn`; anything else is parsed as a native decl.
/// `parserContext` is set to `declContext` for the duration of the parse.
FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
    ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
    StringRef anonymousNamePrefix, bool isInline) {
  SMRange keywordLoc = curToken.getLoc();
  consumeToken();
  llvm::SaveAndRestore saveCtx(parserContext, declContext);

  const ast::Name *declName = nullptr;
  if (curToken.is(Token::identifier)) {
    declName =
        &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  } else if (isInline) {
    // Inline decls behave like C++ lambdas, so they may be unnamed. Give them
    // a unique generated name located at the keyword.
    std::string anonName =
        llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
                      anonymousDeclNameCounter++)
            .str();
    declName = &ast::Name::create(ctx, anonName, keywordLoc);
  } else {
    return emitError("expected identifier name");
  }

  SmallVector<ast::VariableDecl *> arguments, results;
  ast::DeclScope *argumentScope;
  ast::Type resultType;
  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
                                                   argumentScope, resultType)))
    return failure();

  bool hasPDLLBody = curToken.isAny(Token::l_brace, Token::equal_arrow);
  if (!hasPDLLBody)
    return parseUserNativeConstraintOrRewriteDecl<T>(
        *declName, isInline, arguments, results, resultType);
  return parseUserPDLLFn(*declName, isInline, arguments, argumentScope,
                         results, resultType);
}
template <typename T>
/// Parse the tail of a native `Constraint`/`Rewrite` decl: an optional string
/// literal holding the native code body, followed by `;`.
///
/// Without a code body, the decl is an external declaration, which is
/// rejected for inline decls. Native constraints may not declare results.
FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // `codeStorage` owns the text that `codeBody` refers to.
  std::string codeStorage;
  Optional<StringRef> codeBody;
  if (!curToken.isString()) {
    if (isInline)
      return emitError(name.getLoc(),
                       "external declarations must be declared in global scope");
    if (curToken.is(Token::error))
      return failure();
  } else {
    codeStorage = curToken.getStringValue();
    codeBody = codeStorage;
    consumeToken();
  }

  if (failed(parseToken(Token::semicolon,
                        "expected `;` after native declaration")))
    return failure();

  // TODO: PDL should be able to support constraint results in certain
  // situations, we should revise this.
  constexpr bool isConstraintDecl =
      std::is_same<ast::UserConstraintDecl, T>::value;
  if (isConstraintDecl && !results.empty())
    return emitError(
        "native Constraints currently do not support returning results");

  return T::createNative(ctx, name, arguments, results, codeBody, resultType);
}
/// Parse the signature `(<args>) [-> <result> | -> (<results>)]` of a
/// Constraint or Rewrite.
///
/// Outputs:
/// - `arguments`, `results`: filled with the parsed variables.
/// - `argumentScope`: the scope the arguments were defined in. It is popped
///   here, but returned so the body parser can re-enter it.
/// - `resultType`: computed from `results`.
///
/// A lone result may not be named.
LogicalResult Parser::parseUserConstraintOrRewriteSignature(
SmallVectorImpl<ast::VariableDecl *> &arguments,
SmallVectorImpl<ast::VariableDecl *> &results,
ast::DeclScope *&argumentScope, ast::Type &resultType) {
// Parse the argument list of the decl.
if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
return failure();
// NOTE(review): the early `return failure()` paths below leave the pushed
// scope on the stack; presumably harmless because parsing aborts on
// failure. Confirm.
argumentScope = pushDeclScope();
if (curToken.isNot(Token::r_paren)) {
do {
FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
if (failed(argument))
return failure();
arguments.emplace_back(*argument);
} while (consumeIf(Token::comma));
}
popDeclScope();
if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
return failure();
// Parse the results of the decl. They are parsed in a temporary scope that is
// discarded afterwards, presumably so named results neither leak into nor
// collide with the argument scope.
pushDeclScope();
if (consumeIf(Token::arrow)) {
auto parseResultFn = [&]() -> LogicalResult {
FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
if (failed(result))
return failure();
results.emplace_back(*result);
return success();
};
// Check for a list of results.
if (consumeIf(Token::l_paren)) {
do {
if (failed(parseResultFn()))
return failure();
} while (consumeIf(Token::comma));
if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
return failure();
// Otherwise, there is only one result.
} else if (failed(parseResultFn())) {
return failure();
}
}
popDeclScope();
// Compute the result type of the decl.
resultType = createUserConstraintRewriteResultType(results);
// Verify that results are only named if there is more than one: a single
// named result would require a single-element labeled tuple.
if (results.size() == 1 && !results.front()->getName().getName().empty()) {
return emitError(
results.front()->getLoc(),
"cannot create a single-element tuple with an element label");
}
return success();
}
/// Check how `return` is placed within the body of a `declType` decl.
///
/// `bodyIt` points at the first `return` statement, or equals `bodyE` if there
/// is none. A `return` must be the last statement. Omitting it is only allowed
/// when the decl declares no results.
LogicalResult Parser::validateUserConstraintOrRewriteReturn(
    StringRef declType, ast::CompoundStmt *body,
    ArrayRef<ast::Stmt *>::iterator bodyIt,
    ArrayRef<ast::Stmt *>::iterator bodyE,
    ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
  if (bodyIt == bodyE) {
    if (results.empty())
      return success();
    // Point the diagnostic at the very end of the body.
    SMLoc bodyEnd = body->getLoc().End;
    return emitError(
        {bodyEnd, bodyEnd},
        llvm::formatv("missing return in a `{0}` expected to return `{1}`",
                      declType, resultType));
  }

  auto trailingIt = std::next(bodyIt);
  if (trailingIt == bodyE)
    return success();
  return emitError(
      (*trailingIt)->getLoc(),
      llvm::formatv("`return` terminated the `{0}` body, but found "
                    "trailing statements afterwards",
                    declType));
}
/// Parse a `Pattern` lambda body (`=> <stmt>`). The single statement must be
/// an operation rewrite statement.
FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
  auto requireRewriteStmt = [&](ast::Stmt *&statement) -> LogicalResult {
    if (!isa<ast::OpRewriteStmt>(statement)) {
      return emitError(
          statement->getLoc(),
          "expected Pattern lambda body to contain a single operation "
          "rewrite statement, such as `erase`, `replace`, or `rewrite`");
    }
    return success();
  };
  return parseLambdaBody(requireRewriteStmt);
}
/// Parse a `Pattern` declaration:
/// `Pattern [name] [with <metadata>] (=> <rewrite-stmt> | { ... })`.
///
/// A compound body must:
/// - contain no `return` statements;
/// - end with exactly one operation rewrite statement, with nothing after it.
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_Pattern);
// The parser context is restored when this function returns.
llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
// Check for an optional identifier for the pattern name.
const ast::Name *name = nullptr;
if (curToken.is(Token::identifier)) {
name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
consumeToken(Token::identifier);
}
// Parse any pattern metadata.
ParsedPatternMetadata metadata;
if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
return failure();
// Parse the pattern body.
ast::CompoundStmt *body;
// Handle a lambda body.
if (curToken.is(Token::equal_arrow)) {
FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
if (failed(bodyResult))
return failure();
body = *bodyResult;
} else {
if (curToken.isNot(Token::l_brace))
return emitError("expected `{` or `=>` to start pattern body");
FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
if (failed(bodyResult))
return failure();
body = *bodyResult;
// Verify the body of the pattern. Scan for the first operation rewrite
// statement, rejecting any `return` encountered on the way.
auto bodyIt = body->begin(), bodyE = body->end();
for (; bodyIt != bodyE; ++bodyIt) {
if (isa<ast::ReturnStmt>(*bodyIt)) {
return emitError((*bodyIt)->getLoc(),
"`return` statements are only permitted within a "
"`Constraint` or `Rewrite` body");
}
// Break when we've found the rewrite statement.
if (isa<ast::OpRewriteStmt>(*bodyIt))
break;
}
if (bodyIt == bodyE) {
return emitError(loc,
"expected Pattern body to terminate with an operation "
"rewrite statement, such as `erase`");
}
// The rewrite statement must be the last statement of the body.
if (std::next(bodyIt) != bodyE) {
return emitError((*std::next(bodyIt))->getLoc(),
"Pattern body was terminated by an operation "
"rewrite statement, but found trailing statements");
}
}
return createPatternDecl(loc, name, metadata, body);
}
/// Parse the comma-separated metadata list that follows `with` in a `Pattern`
/// declaration. Two entries are supported, each at most once:
/// - `benefit(<integer>)`: base-10, must fit in 16 bits;
/// - `recursion`: marks the pattern as having bounded recursion.
LogicalResult
Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
// Locations of already-seen entries, used to diagnose duplicates.
Optional<SMRange> benefitLoc;
Optional<SMRange> hasBoundedRecursionLoc;
do {
// Handle metadata code completion.
if (curToken.is(Token::code_complete))
return codeCompletePatternMetadata();
if (curToken.isNot(Token::identifier))
return emitError("expected pattern metadata identifier");
StringRef metadataStr = curToken.getSpelling();
SMRange metadataLoc = curToken.getLoc();
consumeToken(Token::identifier);
// Parse the benefit metadata: benefit(<integer-value>)
if (metadataStr == "benefit") {
if (benefitLoc) {
return emitErrorAndNote(metadataLoc,
"pattern benefit has already been specified",
*benefitLoc, "see previous definition here");
}
if (failed(parseToken(Token::l_paren,
"expected `(` before pattern benefit")))
return failure();
uint16_t benefitValue = 0;
if (curToken.isNot(Token::integer))
return emitError("expected integral pattern benefit");
// getAsInteger returns true on failure, including values that overflow
// uint16_t.
if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
return emitError(
"expected pattern benefit to fit within a 16-bit integer");
consumeToken(Token::integer);
metadata.benefit = benefitValue;
benefitLoc = metadataLoc;
if (failed(
parseToken(Token::r_paren, "expected `)` after pattern benefit")))
return failure();
continue;
}
// Parse the bounded recursion metadata: recursion
if (metadataStr == "recursion") {
if (hasBoundedRecursionLoc) {
return emitErrorAndNote(
metadataLoc,
"pattern recursion metadata has already been specified",
*hasBoundedRecursionLoc, "see previous definition here");
}
metadata.hasBoundedRecursion = true;
hasBoundedRecursionLoc = metadataLoc;
continue;
}
return emitError(metadataLoc, "unknown pattern metadata");
} while (consumeIf(Token::comma));
return success();
}
/// Parse an inline type constraint `< <expr> >`. The caller guarantees that
/// the current token is `<`.
FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
  consumeToken(Token::less);

  FailureOr<ast::Expr *> constraintExpr = parseExpr();
  if (failed(constraintExpr))
    return failure();
  if (failed(parseToken(Token::greater,
                        "expected `>` after variable type constraint")))
    return failure();
  return constraintExpr;
}
/// Check that `name` is not already bound in the current decl scope (or any
/// scope visible from it). On a conflict, emit an error with a note pointing
/// at the previous definition.
LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
  assert(curDeclScope && "defining decl outside of a decl scope");

  ast::Decl *existing = curDeclScope->lookup(name.getName());
  if (!existing)
    return success();
  return emitErrorAndNote(
      name.getLoc(), "`" + name.getName() + "` has already been defined",
      existing->getName()->getLoc(), "see previous definition here");
}
/// Create a variable decl named `name`, with an optional initializer
/// `initExpr`.
///
/// Named variables are checked for redefinition and added to the current
/// scope. Empty and `_` names denote special variables that are created but
/// never registered, so they are local to their definition point.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                           ast::Expr *initExpr,
                           ArrayRef<ast::ConstraintRef> constraints) {
  assert(curDeclScope && "defining variable outside of decl scope");
  const ast::Name &declName = ast::Name::create(ctx, name, nameLoc);

  bool isScoped = !name.empty() && name != "_";
  if (isScoped && failed(checkDefineNamedDecl(declName)))
    return failure();

  auto *var =
      ast::VariableDecl::create(ctx, declName, type, initExpr, constraints);
  if (isScoped)
    curDeclScope->add(var);
  return var;
}
/// Convenience overload of defineVariableDecl for a variable without an
/// initializer expression.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                           ArrayRef<ast::ConstraintRef> constraints) {
  ast::Expr *noInitializer = nullptr;
  return defineVariableDecl(name, nameLoc, type, noInitializer, constraints);
}
/// Parse the constraints of a variable declaration. The list is either a
/// single constraint or a bracketed, comma-separated list `[c1, c2, ...]`.
/// Parsed constraints are appended to `constraints`.
LogicalResult Parser::parseVariableDeclConstraintList(
    SmallVectorImpl<ast::ConstraintRef> &constraints) {
  // Shared across the whole list, so a second inline type constraint can be
  // diagnosed against the first one.
  Optional<SMRange> typeConstraint;
  auto parseOne = [&]() -> LogicalResult {
    FailureOr<ast::ConstraintRef> parsed = parseConstraint(
        typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
    if (failed(parsed))
      return failure();
    constraints.push_back(*parsed);
    return success();
  };

  bool isBracketedList = consumeIf(Token::l_square);
  if (!isBracketedList)
    return parseOne();

  do {
    if (failed(parseOne()))
      return failure();
  } while (consumeIf(Token::comma));
  return parseToken(Token::r_square, "expected `]` after constraint list");
}
/// Parse a single constraint reference. It may be one of:
/// - a builtin (`Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`);
/// - an inline `Constraint` decl;
/// - the name of an existing constraint decl.
///
/// `Attr`, `Value`, and `ValueRange` may carry an inline type constraint
/// `<expr>`, but only if `allowInlineTypeConstraints` is set.
/// `typeConstraint` records the location of such a constraint, so that a
/// second one on the same variable is rejected. `existingConstraints` is only
/// used to infer a type for code completion.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
ArrayRef<ast::ConstraintRef> existingConstraints,
bool allowInlineTypeConstraints) {
// Parse an inline `<expr>` type constraint into `typeExpr`, enforcing that it
// is permitted here and has not already been given for this variable.
auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
if (!allowInlineTypeConstraints) {
return emitError(
curToken.getLoc(),
"inline `Attr`, `Value`, and `ValueRange` type constraints are not "
"permitted on arguments or results");
}
if (typeConstraint)
return emitErrorAndNote(
curToken.getLoc(),
"the type of this variable has already been constrained",
*typeConstraint, "see previous constraint location here");
FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
if (failed(constraintExpr))
return failure();
typeExpr = *constraintExpr;
typeConstraint = typeExpr->getLoc();
return success();
};
SMRange loc = curToken.getLoc();
switch (curToken.getKind()) {
case Token::kw_Attr: {
consumeToken(Token::kw_Attr);
// Check for a type constraint.
ast::Expr *typeExpr = nullptr;
if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
return failure();
return ast::ConstraintRef(
ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
}
case Token::kw_Op: {
consumeToken(Token::kw_Op);
// Parse an optional operation name. If the name isn't provided, this refers
// to "any" operation.
FailureOr<ast::OpNameDecl *> opName =
parseWrappedOperationName(/*allowEmptyName=*/true);
if (failed(opName))
return failure();
return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
loc);
}
case Token::kw_Type:
consumeToken(Token::kw_Type);
return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
case Token::kw_TypeRange:
consumeToken(Token::kw_TypeRange);
return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
loc);
case Token::kw_Value: {
consumeToken(Token::kw_Value);
// Check for a type constraint.
ast::Expr *typeExpr = nullptr;
if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
return failure();
return ast::ConstraintRef(
ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
}
case Token::kw_ValueRange: {
consumeToken(Token::kw_ValueRange);
// Check for a type constraint.
ast::Expr *typeExpr = nullptr;
if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
return failure();
return ast::ConstraintRef(
ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
}
case Token::kw_Constraint: {
// Handle an inline constraint. It is also registered in the current scope
// by parseInlineUserConstraintDecl.
FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
if (failed(decl))
return failure();
return ast::ConstraintRef(*decl, loc);
}
case Token::identifier: {
StringRef constraintName = curToken.getSpelling();
consumeToken(Token::identifier);
// Lookup the referenced constraint.
ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
if (!cstDecl) {
return emitError(loc, "unknown reference to constraint `" +
constraintName + "`");
}
// Handle a reference to a proper constraint.
if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
return ast::ConstraintRef(cst, loc);
// The name resolved, but to something that is not a constraint.
return emitErrorAndNote(
loc, "invalid reference to non-constraint", cstDecl->getLoc(),
"see the definition of `" + constraintName + "` here");
}
// Handle single entity constraint code completion.
case Token::code_complete: {
// Try to infer the current type for use by code completion.
ast::Type inferredType;
if (failed(validateVariableConstraints(existingConstraints, inferredType)))
return failure();
return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
}
default:
break;
}
return emitError(loc, "expected identifier constraint");
}
/// Parse the constraint of a Constraint/Rewrite argument or result.
///
/// These start with no prior constraints, and inline type constraints (e.g.
/// `Value<...>`) are not permitted on them.
FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
  Optional<SMRange> noPriorTypeConstraint;
  return parseConstraint(noPriorTypeConstraint,
                         /*existingConstraints=*/std::nullopt,
                         /*allowInlineTypeConstraints=*/false);
}
//===----------------------------------------------------------------------===//
// Exprs
//===----------------------------------------------------------------------===//
/// Parse an expression.
///
/// A leading `_` is handled entirely by parseUnderscoreExpr. Otherwise the
/// expression is a primary expression (attribute, inline constraint or
/// rewrite lambda, identifier, operation, type, or tuple). It is followed by
/// any number of postfix member accesses (`.name`) and calls (`(args)`),
/// applied left to right.
FailureOr<ast::Expr *> Parser::parseExpr() {
if (curToken.is(Token::underscore))
return parseUnderscoreExpr();
// Parse the LHS expression.
FailureOr<ast::Expr *> lhsExpr;
switch (curToken.getKind()) {
case Token::kw_attr:
lhsExpr = parseAttributeExpr();
break;
case Token::kw_Constraint:
lhsExpr = parseInlineConstraintLambdaExpr();
break;
case Token::identifier:
lhsExpr = parseIdentifierExpr();
break;
case Token::kw_op:
lhsExpr = parseOperationExpr();
break;
case Token::kw_Rewrite:
lhsExpr = parseInlineRewriteLambdaExpr();
break;
case Token::kw_type:
lhsExpr = parseTypeExpr();
break;
case Token::l_paren:
lhsExpr = parseTupleExpr();
break;
default:
return emitError("expected expression");
}
if (failed(lhsExpr))
return failure();
// Check for an operator expression. Each postfix operator wraps the
// expression built so far; stop at the first token that is neither `.` nor
// `(`.
while (true) {
switch (curToken.getKind()) {
case Token::dot:
lhsExpr = parseMemberAccessExpr(*lhsExpr);
break;
case Token::l_paren:
lhsExpr = parseCallExpr(*lhsExpr);
break;
default:
return lhsExpr;
}
if (failed(lhsExpr))
return failure();
}
}
/// Parse an attribute literal `attr<"...">`.
///
/// If `attr` is not followed by `<`, it is not an attribute literal. In that
/// case the lexer is rewound and `attr` is reparsed as an identifier
/// expression.
FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
  SMRange exprLoc = curToken.getLoc();
  consumeToken(Token::kw_attr);

  if (!consumeIf(Token::less)) {
    resetToken(exprLoc);
    return parseIdentifierExpr();
  }

  if (!curToken.isString())
    return emitError("expected string literal containing MLIR attribute");
  std::string attrText = curToken.getStringValue();
  consumeToken();

  // The expression spans from `attr` through the closing `>`.
  exprLoc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::greater, "expected `>` after attribute literal")))
    return failure();
  return ast::AttributeExpr::create(ctx, exprLoc, attrText);
}
/// Parse the argument list of a call to `parentExpr`: `(arg0, arg1, ...)`.
///
/// If code completion is requested inside the argument list, the call
/// signature is offered and parsing stops with failure.
FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
  consumeToken(Token::l_paren);

  SmallVector<ast::Expr *> callArgs;
  if (curToken.isNot(Token::r_paren)) {
    do {
      if (curToken.is(Token::code_complete)) {
        codeCompleteCallSignature(parentExpr, callArgs.size());
        return failure();
      }
      FailureOr<ast::Expr *> arg = parseExpr();
      if (failed(arg))
        return failure();
      callArgs.push_back(*arg);
    } while (consumeIf(Token::comma));
  }

  // The call spans from the start of the callee through the closing `)`.
  SMLoc callStart = parentExpr->getLoc().Start;
  SMRange callLoc(callStart, curToken.getEndLoc());
  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
    return failure();
  return createCallExpr(callLoc, parentExpr, callArgs);
}
/// Resolve `name` in the current declaration scope and build a reference to
/// it. If no declaration with that name is visible, report an error at `loc`.
FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
  if (ast::Decl *decl = curDeclScope->lookup(name))
    return createDeclRefExpr(loc, decl);
  return emitError(loc, "undefined reference to `" + name + "`");
}
/// Parse an identifier expression. This is either a reference to an existing
/// declaration or, when followed by `:`, an inline variable definition with a
/// constraint list.
FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
  StringRef name = curToken.getSpelling();
  SMRange nameLoc = curToken.getLoc();
  consumeToken();

  // Without a `:`, this is a plain reference to a previously defined decl.
  if (!consumeIf(Token::colon))
    return parseDeclRefExpr(name, nameLoc);

  // Otherwise, define a new variable inline from the given constraints.
  SmallVector<ast::ConstraintRef> constraints;
  if (failed(parseVariableDeclConstraintList(constraints)))
    return failure();

  ast::Type varType;
  if (failed(validateVariableConstraints(constraints, varType)))
    return failure();
  return createInlineVariableExpr(varType, name, nameLoc, constraints);
}
/// Parse an inline `Constraint` lambda and return a reference expression (of
/// Constraint type) to the constraint declaration it defines.
FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
  FailureOr<ast::UserConstraintDecl *> cstDecl =
      parseInlineUserConstraintDecl();
  if (succeeded(cstDecl)) {
    ast::UserConstraintDecl *decl = *cstDecl;
    return ast::DeclRefExpr::create(ctx, decl->getLoc(), decl,
                                    ast::ConstraintType::get(ctx));
  }
  return failure();
}
/// Parse an inline `Rewrite` lambda and return a reference expression (of
/// Rewrite type) to the rewrite declaration it defines.
FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
  FailureOr<ast::UserRewriteDecl *> rewriteDecl = parseInlineUserRewriteDecl();
  if (succeeded(rewriteDecl)) {
    ast::UserRewriteDecl *decl = *rewriteDecl;
    return ast::DeclRefExpr::create(ctx, decl->getLoc(), decl,
                                    ast::RewriteType::get(ctx));
  }
  return failure();
}
/// Parse `.member` following `parentExpr`. The member may be an identifier, a
/// keyword, or an integer index.
FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
  SMRange dotLoc = curToken.getLoc();
  consumeToken(Token::dot);

  // Offer completion for the member name.
  if (curToken.is(Token::code_complete))
    return codeCompleteMemberAccess(parentExpr);

  bool isValidMemberName = curToken.isAny(Token::identifier, Token::integer) ||
                           curToken.isKeyword();
  if (!isValidMemberName)
    return emitError(dotLoc, "expected identifier or numeric member name");

  // The access spans from the start of the parent through the member name.
  StringRef memberName = curToken.getSpelling();
  SMRange accessLoc(parentExpr->getLoc().Start, curToken.getEndLoc());
  consumeToken();
  return createMemberAccessExpr(parentExpr, memberName, accessLoc);
}
/// Parse an operation name of the form `dialect.op[.more...]`.
///
/// If `allowEmptyName` is set and the current token cannot start a name, an
/// OpNameDecl without a name is returned instead of an error.
///
/// The returned name is a view into the source buffer, formed by widening a
/// StringRef across consecutive tokens. That is only correct when each token
/// directly follows the previous one. Whitespace or comments between pieces
/// are therefore rejected, for the dialect `.` and the first op-name piece, or
/// terminate the name, for later pieces. Without this check the name would
/// silently be corrupted (e.g. `foo . bar` would become `foo . b`).
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  SMRange loc = curToken.getLoc();

  // Check for code completion for the dialect name.
  if (curToken.is(Token::code_complete))
    return codeCompleteDialectName();

  // Handle the case of no operation name.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, SMRange());
    return emitError("expected dialect namespace");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Returns true if the current token begins exactly where `prefix` ends in
  // the source buffer, i.e. nothing separates them.
  auto adjoins = [&](StringRef prefix) {
    return curToken.getSpelling().data() == prefix.end();
  };

  // Otherwise, this is a literal operation name.
  if (curToken.is(Token::dot) && !adjoins(name))
    return emitError("expected `.` to immediately follow the dialect namespace");
  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();

  // Check for code completion for the operation name.
  if (curToken.is(Token::code_complete))
    return codeCompleteOperationName(name);

  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Include the `.` that separates the dialect from the operation name.
  name = StringRef(name.data(), name.size() + 1);
  if (!adjoins(name))
    return emitError("expected operation name to immediately follow `.`");

  // Absorb the remaining identifier/keyword/`.` pieces, stopping at the first
  // token that does not directly continue the name.
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while ((curToken.isAny(Token::identifier, Token::dot) ||
            curToken.isKeyword()) &&
           adjoins(name));
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
}
/// Parse an optional operation name wrapped in angle brackets, `<dialect.op>`.
/// If there is no leading `<`, an OpNameDecl without a name is returned.
FailureOr<ast::OpNameDecl *>
Parser::parseWrappedOperationName(bool allowEmptyName) {
  // No angle brackets means the operation name is left unspecified.
  if (!consumeIf(Token::less))
    return ast::OpNameDecl::create(ctx, SMRange());

  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
  bool parsedName =
      succeeded(opNameDecl) &&
      succeeded(parseToken(Token::greater, "expected `>` after operation name"));
  if (!parsedName)
    return failure();
  return opNameDecl;
}
/// Parse an operation expression of the form:
///   op<dialect.name>(operands) {attributes} -> (result types)
/// where the name, operand list, attribute dictionary and result list are each
/// optional. A bare `op` keyword that is not followed by `<` is parsed as an
/// ordinary identifier.
///
/// `inputResultTypeContext` is the result type context to use when no
/// explicit `-> (...)` list is present (it may be adjusted below).
FailureOr<ast::Expr *>
Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_op);
  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
  // identifier.
  if (curToken.isNot(Token::less)) {
    resetToken(loc);
    return parseIdentifierExpr();
  }
  // Parse the operation name. The name may be elided, in which case the
  // operation refers to "any" operation (i.e. the difference between `MyOp`
  // and `Operation*`). Operations within a rewrite context must be named.
  bool allowEmptyName = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(allowEmptyName);
  if (failed(opNameDecl))
    return failure();
  Optional<StringRef> opName = (*opNameDecl)->getName();
  // Functor used to create an implicit range variable, used for implicit "all"
  // operand or results variables. The variable is named `_` and is constrained
  // by `cst`.
  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
    FailureOr<ast::VariableDecl *> rangeVar =
        defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
    assert(succeeded(rangeVar) && "expected range variable to be valid");
    return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
  };
  // Check for the optional list of operands.
  SmallVector<ast::Expr *> operands;
  if (!consumeIf(Token::l_paren)) {
    // If the operand list isn't specified and we are in a match context, define
    // an in-place unconstrained operand range corresponding to all of the
    // operands of the operation. This avoids treating zero operands the same
    // way as "unconstrained operands".
    if (parserContext != ParserContext::Rewrite) {
      operands.push_back(createImplicitRangeVar(
          ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
    }
  } else if (!consumeIf(Token::r_paren)) {
    // If the operand list was specified and non-empty, parse the operands.
    do {
      // Check for operand signature code completion.
      if (curToken.is(Token::code_complete)) {
        codeCompleteOperationOperandsSignature(opName, operands.size());
        return failure();
      }
      FailureOr<ast::Expr *> operand = parseExpr();
      if (failed(operand))
        return failure();
      operands.push_back(*operand);
    } while (consumeIf(Token::comma));
    if (failed(parseToken(Token::r_paren,
                          "expected `)` after operation operand list")))
      return failure();
  }
  // Check for the optional list of attributes. When a `{` is present, at least
  // one named attribute must follow.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> decl =
          parseNamedAttributeDecl(opName);
      if (failed(decl))
        return failure();
      attributes.emplace_back(*decl);
    } while (consumeIf(Token::comma));
    if (failed(parseToken(Token::r_brace,
                          "expected `}` after operation attribute list")))
      return failure();
  }
  // Handle the result types of the operation.
  SmallVector<ast::Expr *> resultTypes;
  OpResultTypeContext resultTypeContext = inputResultTypeContext;
  // Check for an explicit list of result types.
  if (consumeIf(Token::arrow)) {
    if (failed(parseToken(Token::l_paren,
                          "expected `(` before operation result type list")))
      return failure();
    // If result types are provided, initially assume that the operation does
    // not rely on type inference. We don't assert that it isn't, because we
    // may be inferring the value of some type/type range variables, but given
    // that these variables may be defined in calls we can't always discern when
    // this is the case.
    resultTypeContext = OpResultTypeContext::Explicit;
    // Handle the case of an empty result list.
    if (!consumeIf(Token::r_paren)) {
      do {
        // Check for result signature code completion.
        if (curToken.is(Token::code_complete)) {
          codeCompleteOperationResultsSignature(opName, resultTypes.size());
          return failure();
        }
        FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
        if (failed(resultTypeExpr))
          return failure();
        resultTypes.push_back(*resultTypeExpr);
      } while (consumeIf(Token::comma));
      if (failed(parseToken(Token::r_paren,
                            "expected `)` after operation result type list")))
        return failure();
    }
  } else if (parserContext != ParserContext::Rewrite) {
    // If the result list isn't specified and we are in a match context, define
    // an in-place unconstrained result range corresponding to all of the
    // results of the operation. This avoids treating zero results the same way
    // as "unconstrained results".
    resultTypes.push_back(createImplicitRangeVar(
        ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
    // If the result list isn't specified and we are in a rewrite, try to infer
    // them at runtime instead. Other non-explicit contexts (e.g. Replacement)
    // are kept as provided by the caller.
    resultTypeContext = OpResultTypeContext::Interface;
  }
  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
                             attributes, resultTypes);
}
/// Parse a parenthesized tuple expression, `(elt, label = elt, ...)`. Each
/// element may optionally be prefixed by `label =`. Labels must be unique
/// within a tuple. The list may be empty, `()`.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::l_paren);
  // Maps each label already used in this tuple to its location, for duplicate
  // diagnostics.
  DenseMap<StringRef, SMRange> usedNames;
  // Parallel to `elements`; an empty StringRef means the element is unlabeled.
  SmallVector<StringRef> elementNames;
  SmallVector<ast::Expr *> elements;
  if (curToken.isNot(Token::r_paren)) {
    do {
      // Check for the optional element name assignment before the value.
      StringRef elementName;
      if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
        Token elementNameTok = curToken;
        consumeToken();
        // The element name is only present if followed by an `=`.
        if (consumeIf(Token::equal)) {
          elementName = elementNameTok.getSpelling();
          // Check to see if this name is already used.
          auto elementNameIt =
              usedNames.try_emplace(elementName, elementNameTok.getLoc());
          if (!elementNameIt.second) {
            return emitErrorAndNote(
                elementNameTok.getLoc(),
                llvm::formatv("duplicate tuple element label `{0}`",
                              elementName),
                elementNameIt.first->getSecond(),
                "see previous label use here");
          }
        } else {
          // Otherwise, we treat this as part of an expression so reset the
          // lexer back to the start of the token we speculatively consumed.
          resetToken(elementNameTok.getLoc());
        }
      }
      elementNames.push_back(elementName);
      // Parse the tuple element value.
      FailureOr<ast::Expr *> element = parseExpr();
      if (failed(element))
        return failure();
      elements.push_back(*element);
    } while (consumeIf(Token::comma));
  }
  loc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::r_paren, "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(loc, elements, elementNames);
}
/// Parse a type literal of the form `type<"...">`. A bare `type` that is not
/// followed by `<` is rewound and parsed as an ordinary identifier.
FailureOr<ast::Expr *> Parser::parseTypeExpr() {
  SMRange typeLoc = curToken.getLoc();
  consumeToken(Token::kw_type);

  // Without a `<`, `type` is just a name: rewind and parse it as one.
  if (!consumeIf(Token::less)) {
    resetToken(typeLoc);
    return parseIdentifierExpr();
  }

  if (!curToken.isString())
    return emitError("expected string literal containing MLIR type");
  std::string typeStr = curToken.getStringValue();
  consumeToken();

  // The expression's range extends through the closing `>`.
  typeLoc.End = curToken.getEndLoc();
  LogicalResult closed =
      parseToken(Token::greater, "expected `>` after type literal");
  if (failed(closed))
    return failure();
  return ast::TypeExpr::create(ctx, typeLoc, typeStr);
}
/// Parse an inline `_` variable, e.g. `_: SomeConstraint`. Unlike named inline
/// variables, the `:` constraint list is mandatory.
FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
  StringRef varName = curToken.getSpelling();
  SMRange varLoc = curToken.getLoc();
  consumeToken(Token::underscore);

  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
    return failure();

  // Parse and validate the constraints, which also determine the type.
  SmallVector<ast::ConstraintRef> constraints;
  ast::Type varType;
  if (failed(parseVariableDeclConstraintList(constraints)) ||
      failed(validateVariableConstraints(constraints, varType)))
    return failure();
  return createInlineVariableExpr(varType, varName, varLoc, constraints);
}
//===----------------------------------------------------------------------===//
// Stmts
/// Parse a single statement, dispatching on its leading keyword. Anything
/// else is parsed as an expression statement. When `expectTerminalSemicolon`
/// is set, the statement must be followed by `;`.
FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
  auto parseStmtBody = [&]() -> FailureOr<ast::Stmt *> {
    switch (curToken.getKind()) {
    case Token::kw_erase:
      return parseEraseStmt();
    case Token::kw_let:
      return parseLetStmt();
    case Token::kw_replace:
      return parseReplaceStmt();
    case Token::kw_return:
      return parseReturnStmt();
    case Token::kw_rewrite:
      return parseRewriteStmt();
    default:
      return parseExpr();
    }
  };

  FailureOr<ast::Stmt *> stmt = parseStmtBody();
  if (failed(stmt))
    return failure();
  if (expectTerminalSemicolon &&
      failed(parseToken(Token::semicolon, "expected `;` after statement")))
    return failure();
  return stmt;
}
/// Parse a `{ ... }` block of statements. Declarations made inside the block
/// live in their own scope, which is popped on both success and failure.
FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
  SMLoc blockStart = curToken.getStartLoc();
  consumeToken(Token::l_brace);

  pushDeclScope();
  SmallVector<ast::Stmt *> stmts;
  while (curToken.isNot(Token::r_brace)) {
    FailureOr<ast::Stmt *> stmt = parseStmt();
    if (failed(stmt)) {
      popDeclScope();
      return failure();
    }
    stmts.push_back(*stmt);
  }
  popDeclScope();

  // The block's range ends at the closing brace.
  SMRange blockLoc(blockStart, curToken.getEndLoc());
  consumeToken(Token::r_brace);
  return ast::CompoundStmt::create(ctx, blockLoc, stmts);
}
/// Parse `erase <root-op-expr>`. This statement is rejected inside a
/// `Constraint` body.
FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
  if (parserContext == ParserContext::Constraint)
    return emitError("`erase` cannot be used within a Constraint");

  SMRange eraseLoc = curToken.getLoc();
  consumeToken(Token::kw_erase);

  FailureOr<ast::Expr *> opExpr = parseExpr();
  if (succeeded(opExpr))
    return createEraseStmt(eraseLoc, *opExpr);
  return failure();
}
/// Parse a `let` statement: `let name [: constraints] [= initializer]`. The
/// trailing `;` is not consumed here (parseStmt handles it).
///
/// Type constraints (the type expression on `Attr`/`Value`/`ValueRange`
/// constraints) are rejected when an initializer is present.
FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_let);

  // Parse the name of the new variable.
  SMRange varLoc = curToken.getLoc();
  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
    // `_` is a reserved variable name.
    if (curToken.is(Token::underscore)) {
      return emitError(varLoc,
                       "`_` may only be used to define \"inline\" variables");
    }
    return emitError(varLoc,
                     "expected identifier after `let` to name a new variable");
  }
  StringRef varName = curToken.getSpelling();
  consumeToken();

  // Parse the optional set of constraints.
  SmallVector<ast::ConstraintRef> constraints;
  if (consumeIf(Token::colon) &&
      failed(parseVariableDeclConstraintList(constraints)))
    return failure();

  // Parse the optional initializer expression.
  ast::Expr *initializer = nullptr;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> initOrFailure = parseExpr();
    if (failed(initOrFailure))
      return failure();
    initializer = *initOrFailure;

    // Check that the constraints are compatible with having an initializer,
    // e.g. type constraints cannot be used with initializers.
    for (const ast::ConstraintRef &constraint : constraints) {
      LogicalResult result =
          TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
              .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
                    ast::ValueRangeConstraintDecl>([&](const auto *cst) {
                // Only the presence of a type constraint matters here.
                if (cst->getTypeExpr()) {
                  return this->emitError(
                      constraint.referenceLoc,
                      "type constraints are not permitted on variables with "
                      "initializers");
                }
                return success();
              })
              .Default(success());
      if (failed(result))
        return failure();
    }
  }

  FailureOr<ast::VariableDecl *> varDecl =
      createVariableDecl(varName, varLoc, initializer, constraints);
  if (failed(varDecl))
    return failure();
  return ast::LetStmt::create(ctx, loc, *varDecl);
}
/// Parse `replace <root-op-expr> with <replacement>`, where the replacement is
/// either a single expression or a non-empty parenthesized list of values.
/// The statement is rejected inside a `Constraint` body, and the replacement
/// portion is parsed in a rewrite context.
FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
  if (parserContext == ParserContext::Constraint)
    return emitError("`replace` cannot be used within a Constraint");
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_replace);
  // Parse the root operation expression.
  FailureOr<ast::Expr *> rootOp = parseExpr();
  if (failed(rootOp))
    return failure();
  if (failed(
          parseToken(Token::kw_with, "expected `with` after root operation")))
    return failure();
  // The replacement portion of this statement is within a rewrite context.
  // The previous context is restored when `saveCtx` goes out of scope.
  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
  // Parse the replacement values.
  SmallVector<ast::Expr *> replValues;
  if (consumeIf(Token::l_paren)) {
    if (consumeIf(Token::r_paren)) {
      return emitError(
          loc, "expected at least one replacement value, consider using "
               "`erase` if no replacement values are desired");
    }
    do {
      FailureOr<ast::Expr *> replExpr = parseExpr();
      if (failed(replExpr))
        return failure();
      replValues.emplace_back(*replExpr);
    } while (consumeIf(Token::comma));
    if (failed(parseToken(Token::r_paren,
                          "expected `)` after replacement values")))
      return failure();
  } else {
    // Handle replacement with an operation uniquely, as the replacement
    // operation supports type inference from the root operation.
    FailureOr<ast::Expr *> replExpr;
    if (curToken.is(Token::kw_op))
      replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
    else
      replExpr = parseExpr();
    if (failed(replExpr))
      return failure();
    replValues.emplace_back(*replExpr);
  }
  return createReplaceStmt(loc, *rootOp, replValues);
}
/// Parse `return <expr>`.
FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
  SMRange returnLoc = curToken.getLoc();
  consumeToken(Token::kw_return);

  FailureOr<ast::Expr *> value = parseExpr();
  if (succeeded(value))
    return ast::ReturnStmt::create(ctx, returnLoc, *value);
  return failure();
}
/// Parse `rewrite <root-op-expr> with { ... }`. The statement is rejected
/// inside a `Constraint` body. The body is parsed in a rewrite context, and
/// its top-level statements may not include `return`.
FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
  if (parserContext == ParserContext::Constraint)
    return emitError("`rewrite` cannot be used within a Constraint");

  SMRange rewriteLoc = curToken.getLoc();
  consumeToken(Token::kw_rewrite);

  FailureOr<ast::Expr *> rootOp = parseExpr();
  if (failed(rootOp) ||
      failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
    return failure();
  if (curToken.isNot(Token::l_brace))
    return emitError("expected `{` to start rewrite body");

  // Switch to a rewrite context for the rest of this statement.
  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
  FailureOr<ast::CompoundStmt *> body = parseCompoundStmt();
  if (failed(body))
    return failure();

  // Reject the first top-level `return` in the body, if there is one.
  auto bodyStmts = (*body)->getChildren();
  auto returnIt = llvm::find_if(bodyStmts, [](const ast::Stmt *stmt) {
    return isa<ast::ReturnStmt>(stmt);
  });
  if (returnIt != bodyStmts.end()) {
    return emitError((*returnIt)->getLoc(),
                     "`return` statements are only permitted within a "
                     "`Constraint` or `Rewrite` body");
  }
  return createRewriteStmt(rewriteLoc, *rootOp, *body);
}
//===----------------------------------------------------------------------===//
// Creation+Analysis
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Decls
/// Return the callable declaration denoted by `node`, looking through a
/// DeclRefExpr to its referenced decl, or null if it is not a CallableDecl.
ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
  auto *declRef = dyn_cast<ast::DeclRefExpr>(node);
  ast::Node *target = declRef ? declRef->getDecl() : node;
  return dyn_cast<ast::CallableDecl>(target);
}
/// Build a PatternDecl from its location, optional name, parsed metadata
/// (benefit and bounded-recursion flag) and body.
FailureOr<ast::PatternDecl *>
Parser::createPatternDecl(SMRange loc, const ast::Name *name,
                          const ParsedPatternMetadata &metadata,
                          ast::CompoundStmt *body) {
  const auto &benefit = metadata.benefit;
  const auto &boundedRecursion = metadata.hasBoundedRecursion;
  return ast::PatternDecl::create(ctx, loc, name, benefit, boundedRecursion,
                                  body);
}
/// Compute the result type of a user Constraint/Rewrite from its declared
/// result variables. A single result uses that result's type. Any other count
/// yields a tuple whose element types and labels come from the variables.
ast::Type Parser::createUserConstraintRewriteResultType(
    ArrayRef<ast::VariableDecl *> results) {
  if (results.size() == 1)
    return results.front()->getType();

  SmallVector<ast::Type> elementTypes;
  SmallVector<StringRef> elementNames;
  elementTypes.reserve(results.size());
  elementNames.reserve(results.size());
  for (const ast::VariableDecl *result : results) {
    elementTypes.push_back(result->getType());
    elementNames.push_back(result->getName().getName());
  }
  return ast::TupleType::get(ctx, elementTypes, elementNames);
}
/// Create a PDLL-defined user Constraint/Rewrite decl of kind `T`.
///
/// If the body ends in a `return`, its expression decides the result type.
/// With no declared results, the type is inferred from the returned
/// expression. Otherwise the returned expression is converted to
/// `resultType`, which fails on mismatch.
template <typename T>
FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
    const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
    ast::CompoundStmt *body) {
  ast::ReturnStmt *retStmt = nullptr;
  if (!body->getChildren().empty())
    retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back());

  if (retStmt) {
    ast::Expr *resultExpr = retStmt->getResultExpr();
    if (results.empty()) {
      // No explicit signature results: infer them from the return value.
      resultType = resultExpr->getType();
    } else {
      // Explicit results: the return value must convert to the declared type.
      if (failed(convertExpressionTo(resultExpr, resultType)))
        return failure();
      retStmt->setResultExpr(resultExpr);
    }
  }
  return T::createPDLL(ctx, name, arguments, results, body, resultType);
}
/// Create and define a variable named `name` at `loc`.
///
/// Its type comes from the constraint list and/or the initializer. If both
/// provide a type, they are refined together, or else the initializer is
/// converted to the constraint type. It is an error if no type can be
/// determined, or if the type is a Constraint or Rewrite type.
FailureOr<ast::VariableDecl *>
Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
                           ArrayRef<ast::ConstraintRef> constraints) {
  ast::Type type;
  if (failed(validateVariableConstraints(constraints, type)))
    return failure();

  if (!initializer) {
    // Without an initializer, the constraints alone must determine the type.
    if (!type) {
      return emitErrorAndNote(
          loc, "unable to infer type for variable `" + name + "`", loc,
          "the type of a variable must be inferable from the constraint "
          "list or the initializer");
    }
  } else if (!type) {
    type = initializer->getType();
  } else if (ast::Type mergedType = type.refineWith(initializer->getType())) {
    type = mergedType;
  } else if (failed(convertExpressionTo(initializer, type))) {
    return failure();
  }

  // Constraint types cannot be used when defining variables.
  if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
    return emitError(
        loc, llvm::formatv("unable to define variable of `{0}` type", type));
  }
  return defineVariableDecl(name, loc, type, initializer, constraints);
}
/// Define an argument or result variable whose type is inferred from its
/// single constraint.
FailureOr<ast::VariableDecl *>
Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
                                      const ast::ConstraintRef &constraint) {
  ast::Type inferredType;
  if (succeeded(validateVariableConstraint(constraint, inferredType)))
    return defineVariableDecl(name, loc, inferredType, constraint);
  return failure();
}
/// Validate each constraint in order, merging its type into `inferredType`.
/// Validation stops at the first constraint that fails.
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                                    ast::Type &inferredType) {
  bool allValid =
      llvm::all_of(constraints, [&](const ast::ConstraintRef &ref) {
        return succeeded(validateVariableConstraint(ref, inferredType));
      });
  return success(allValid);
}
/// Validate a single variable constraint and merge the type it implies into
/// `inferredType`.
///
/// Each constraint kind maps to a type (Attr -> Attribute, Op -> Operation,
/// Type/TypeRange, Value/ValueRange, or, for a user Constraint, the type of
/// its single input). Any type-constraint expression on the constraint is
/// checked too. If `inferredType` is already set, the new type must refine
/// with it, or an error is emitted.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType) {
  ast::Type constraintType;
  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    // `Attr<type>`: the optional type expression must be a single `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = ast::AttributeType::get(ctx);
  } else if (const auto *cst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    // `Op<name>`: attach the ODS definition for the name when one is known.
    constraintType = ast::OperationType::get(
        ctx, cst->getName(), lookupODSOperation(cst->getName()));
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    constraintType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    constraintType = typeRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    // `Value<type>`: the optional type expression must be a single `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    // `ValueRange<types>`: the optional type expression must be a `TypeRange`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeRangeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
    // A user `Constraint` used in a variable constraint list is applied to
    // that variable, so it must take exactly one input; the input's type is
    // the type the constraint implies.
    ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
    if (inputs.size() != 1) {
      return emitErrorAndNote(ref.referenceLoc,
                              "`Constraint`s applied via a variable constraint "
                              "list must take a single input, but got " +
                                  Twine(inputs.size()),
                              cst->getLoc(),
                              "see definition of constraint here");
    }
    constraintType = inputs.front()->getType();
  } else {
    llvm_unreachable("unknown constraint type");
  }
  // Check that the constraint type is compatible with the current inferred
  // type.
  if (!inferredType) {
    inferredType = constraintType;
  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
    inferredType = mergedTy;
  } else {
    return emitError(ref.referenceLoc,
                     llvm::formatv("constraint type `{0}` is incompatible "
                                   "with the previously inferred type `{1}`",
                                   constraintType, inferredType));
  }
  return success();
}
/// Verify that a single-value type constraint expression has `Type` type.
LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
  if (typeExpr->getType() == typeTy)
    return success();
  return emitError(typeExpr->getLoc(),
                   "expected expression of `Type` in type constraint");
}
/// Verify that a range type constraint expression has `TypeRange` type.
LogicalResult
Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
  if (typeExpr->getType() == typeRangeTy)
    return success();
  return emitError(typeExpr->getLoc(),
                   "expected expression of `TypeRange` in type constraint");
}
//===----------------------------------------------------------------------===//
// Exprs
/// Build a call to `parentExpr` with the given arguments.
///
/// The callee must resolve to a Constraint or Rewrite decl. Constraints may
/// not be called in a rewrite section, and Rewrites may not be called in a
/// match section. The argument count must match the callee's inputs, and each
/// argument is converted in place to the corresponding input type.
FailureOr<ast::CallExpr *>
Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
                       MutableArrayRef<ast::Expr *> arguments) {
  ast::Type parentType = parentExpr->getType();
  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
  if (!callableDecl) {
    return emitError(loc,
                     llvm::formatv("expected a reference to a callable "
                                   "`Constraint` or `Rewrite`, but got: `{0}`",
                                   parentType));
  }
  if (parserContext == ParserContext::Rewrite) {
    if (isa<ast::UserConstraintDecl>(callableDecl))
      return emitError(
          loc, "unable to invoke `Constraint` within a rewrite section");
  } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
    return emitError(loc, "unable to invoke `Rewrite` within a match section");
  }
  // Verify the arguments of the call.
  // Handle size mismatch.
  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
  if (callArgs.size() != arguments.size()) {
    return emitErrorAndNote(
        loc,
        llvm::formatv("invalid number of arguments for {0} call; expected "
                      "{1}, but got {2}",
                      callableDecl->getCallableType(), callArgs.size(),
                      arguments.size()),
        callableDecl->getLoc(),
        llvm::formatv("see the definition of {0} here",
                      callableDecl->getName()->getName()));
  }
  // Handle argument type mismatch. Any conversion diagnostic gets a note
  // pointing at the callee's definition.
  auto attachDiagFn = [&](ast::Diagnostic &diag) {
    diag.attachNote(llvm::formatv("see the definition of `{0}` here",
                                  callableDecl->getName()->getName()),
                    callableDecl->getLoc());
  };
  for (auto it : llvm::zip(callArgs, arguments)) {
    if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
                                   attachDiagFn)))
      return failure();
  }
  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
                               callableDecl->getResultType());
}
/// Build a reference to `decl`. Only constraints, rewrites, and variables may
/// be referenced; each kind determines the type of the reference.
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
                                                        ast::Decl *decl) {
  auto makeRef = [&](ast::Type refType) {
    return ast::DeclRefExpr::create(ctx, loc, decl, refType);
  };

  if (isa<ast::ConstraintDecl>(decl))
    return makeRef(ast::ConstraintType::get(ctx));
  if (isa<ast::UserRewriteDecl>(decl))
    return makeRef(ast::RewriteType::get(ctx));
  if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
    return makeRef(varDecl->getType());
  return emitError(loc, "invalid reference to `" +
                            decl->getName()->getName() + "`");
}
/// Define a variable inline and return a reference to it with type `type`.
FailureOr<ast::DeclRefExpr *>
Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
                                 ArrayRef<ast::ConstraintRef> constraints) {
  FailureOr<ast::VariableDecl *> varDecl =
      defineVariableDecl(name, loc, type, constraints);
  if (succeeded(varDecl))
    return ast::DeclRefExpr::create(ctx, loc, *varDecl, type);
  return failure();
}
/// Build `parentExpr.name` after checking that `name` is a valid member of
/// the parent's type.
FailureOr<ast::MemberAccessExpr *>
Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
                               SMRange loc) {
  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
  if (succeeded(memberType))
    return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name,
                                         *memberType);
  return failure();
}
/// Check that `name` is a valid member of `parentExpr`'s type and return the
/// type of the accessed member. Supported accesses:
///  * operations: the "all results" member, numeric result indices, and (when
///    an ODS definition is available) named results;
///  * tuples: numeric element indices and element labels.
/// Any other access is reported as an error at `loc`.
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
                                                  StringRef name, SMRange loc) {
  ast::Type parentType = parentExpr->getType();

  // A name that starts with a digit denotes positional (index) access. Test
  // for emptiness first so the first character is never read out of bounds.
  bool isNumericName = !name.empty() && llvm::isDigit(name.front());

  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
    if (name == ast::AllResultsMemberAccessExpr::getMemberName())
      return valueRangeTy;

    // Verify member access based on the operation type.
    if (const ods::Operation *odsOp = opType.getODSOperation()) {
      auto results = odsOp->getResults();

      // Handle indexed results; variadic results are accessed as ranges.
      unsigned index = 0;
      if (isNumericName && !name.getAsInteger(/*Radix=*/10, index) &&
          index < results.size()) {
        return results[index].isVariadic() ? valueRangeTy : valueTy;
      }

      // Handle named results.
      const auto *it = llvm::find_if(results, [&](const auto &result) {
        return result.getName() == name;
      });
      if (it != results.end())
        return it->isVariadic() ? valueRangeTy : valueTy;
    } else if (isNumericName) {
      // Allow unchecked numeric indexing of the results of operations without
      // an ODS definition. It returns a single value.
      return valueTy;
    }
  } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
    // Handle indexed elements.
    unsigned index = 0;
    if (isNumericName && !name.getAsInteger(/*Radix=*/10, index) &&
        index < tupleType.size()) {
      return tupleType.getElementTypes()[index];
    }

    // Handle labeled elements.
    auto elementNames = tupleType.getElementNames();
    const auto *it = llvm::find(elementNames, name);
    if (it != elementNames.end())
      return tupleType.getElementTypes()[it - elementNames.begin()];
  }
  return emitError(
      loc,
      llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
                    name, parentType));
}
/// Build an operation expression after validating its pieces.
///
/// The operands are checked against the ODS definition (if any), and every
/// attribute value must have `Attr` type. Results are handled by
/// `resultTypeContext`: explicit result lists are validated, while interface
/// inference may emit a warning (see checkOperationResultTypeInferrence).
FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
    SMRange loc, const ast::OpNameDecl *name,
    OpResultTypeContext resultTypeContext,
    SmallVectorImpl<ast::Expr *> &operands,
    MutableArrayRef<ast::NamedAttributeDecl *> attributes,
    SmallVectorImpl<ast::Expr *> &results) {
  Optional<StringRef> opNameRef = name->getName();
  const ods::Operation *odsOp = lookupODSOperation(opNameRef);
  // Verify the input operands.
  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
    return failure();
  // Verify the attribute list: every value must have attribute type.
  for (ast::NamedAttributeDecl *attr : attributes) {
    // Check for an attribute type, or a type awaiting resolution.
    ast::Type attrType = attr->getValue()->getType();
    if (!attrType.isa<ast::AttributeType>()) {
      return emitError(
          attr->getValue()->getLoc(),
          llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
    }
  }
  assert(
      (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
      "unexpected inferrence when results were explicitly specified");
  // If we aren't relying on type inference, or explicit results were provided,
  // validate them.
  if (resultTypeContext == OpResultTypeContext::Explicit) {
    if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
      return failure();
    // Validate the use of interface based type inference for this operation.
  } else if (resultTypeContext == OpResultTypeContext::Interface) {
    assert(opNameRef &&
           "expected valid operation name when inferring operation results");
    checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
  }
  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
                                    attributes);
}
/// Validate the operand expressions of an operation named `name`, checking
/// them against its ODS definition when `odsOp` is provided.
LogicalResult
Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
                                  const ods::Operation *odsOp,
                                  SmallVectorImpl<ast::Expr *> &operands) {
  // Without an ODS definition, pass no ODS location and an empty expected
  // operand list.
  Optional<SMRange> odsOpLoc;
  ArrayRef<ods::OperandOrResult> odsOperands;
  if (odsOp) {
    odsOpLoc = odsOp->getLoc();
    odsOperands = odsOp->getOperands();
  }
  return validateOperationOperandsOrResults("operand", loc, odsOpLoc, name,
                                            operands, odsOperands, valueTy,
                                            valueRangeTy);
}
/// Validate the result type expressions of an operation named `name`,
/// checking them against its ODS definition when `odsOp` is provided.
LogicalResult
Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
                                 const ods::Operation *odsOp,
                                 SmallVectorImpl<ast::Expr *> &results) {
  // Without an ODS definition, pass no ODS location and an empty expected
  // result list.
  Optional<SMRange> odsOpLoc;
  ArrayRef<ods::OperandOrResult> odsResults;
  if (odsOp) {
    odsOpLoc = odsOp->getLoc();
    odsResults = odsOp->getResults();
  }
  return validateOperationOperandsOrResults("result", loc, odsOpLoc, name,
                                            results, odsResults, typeTy,
                                            typeRangeTy);
}
/// Warn when an operation's result types are marked to be inferred but the
/// parser cannot confirm that the operation supports inference. Only warnings
/// are emitted, never errors, because the interface might be attached to the
/// operation at runtime.
void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
                                                const ods::Operation *odsOp) {
  auto &diagEngine = ctx.getDiagEngine();

  // Without an ODS definition there is nothing to check against, so always
  // warn about the unknown operation.
  if (!odsOp) {
    diagEngine.emitWarning(
        loc, llvm::formatv(
                 "operation result types are marked to be inferred, but "
                 "`{0}` is unknown. Ensure that `{0}` supports zero "
                 "results or implements `InferTypeOpInterface`. Include "
                 "the ODS definition of this operation to remove this warning.",
                 opName));
    return;
  }

  // An elided result list may legitimately mean "zero results". Only warn
  // when ODS declares at least one fixed (non variable-length) result and the
  // operation has no result type inference support.
  bool hasFixedLengthResult =
      llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
        return !result.isVariableLength();
      });
  if (!hasFixedLengthResult || odsOp->hasResultTypeInferrence())
    return;

  ast::InFlightDiagnostic diag = diagEngine.emitWarning(
      loc,
      llvm::formatv("operation result types are marked to be inferred, but "
                    "`{0}` does not provide an implementation of "
                    "`InferTypeOpInterface`. Ensure that `{0}` attaches "
                    "`InferTypeOpInterface` at runtime, or add support to "
                    "the ODS definition to remove this warning.",
                    opName));
  diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
                   odsOp->getLoc());
}
/// Shared verification for the operand and result groups of an operation.
///
/// `groupName` ("operand" or "result" at the visible call sites) is used in
/// diagnostics. `singleTy` and `rangeTy` are the single-element and range
/// types that a group may have.
///
/// When `odsOpLoc` is set, ODS information is available and the groups in
/// `values` are checked against `odsValues`. In a rewrite context, empty
/// placeholder ranges may be appended to `values` to match the ODS
/// signature.
///
/// Without ODS information, each group only needs to have, or be convertible
/// to, one of the accepted types. Expressions in `values` may be replaced by
/// their converted form.
LogicalResult Parser::validateOperationOperandsOrResults(
    StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
    Optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
    ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
    ast::RangeType rangeTy) {
  // A single group is always accepted, interpreted as a range.
  if (values.size() == 1)
    return convertExpressionTo(values[0], rangeTy);

  // Without ODS information, accept the groups as written and only make sure
  // each one has an expected type.
  if (!odsOpLoc) {
    for (ast::Expr *&valueExpr : values) {
      ast::Type exprTy = valueExpr->getType();
      if (exprTy == rangeTy || exprTy == singleTy)
        continue;

      // When values are expected, an `Op` expression may stand in for its
      // results. This is common with nested operation expressions:
      // `op<my_dialect.foo>(op<my_dialect.bar>)`
      if (singleTy == valueTy && exprTy.isa<ast::OperationType>()) {
        valueExpr = convertOpToValue(valueExpr);
        continue;
      }

      // As a last resort, try converting the expression into a range.
      if (failed(convertExpressionTo(valueExpr, rangeTy))) {
        return emitError(
            valueExpr->getLoc(),
            llvm::formatv(
                "expected `{0}` or `{1}` convertible expression, but got `{2}`",
                singleTy, rangeTy, exprTy));
      }
    }
    return success();
  }

  // ODS information is available, so the groups can be verified precisely.
  auto emitSizeMismatchError = [&] {
    return emitErrorAndNote(
        loc,
        llvm::formatv("invalid number of {0} groups for `{1}`; expected "
                      "{2}, but got {3}",
                      groupName, *name, odsValues.size(), values.size()),
        *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
  };

  if (values.empty()) {
    // Providing no groups is fine when ODS expects none.
    if (odsValues.empty())
      return success();

    // Otherwise, providing nothing is only valid if every ODS group is
    // variable-length.
    bool hasRequiredGroup =
        llvm::any_of(odsValues, [](const ods::OperandOrResult &odsValue) {
          return !odsValue.isVariableLength();
        });
    if (hasRequiredGroup)
      return emitSizeMismatchError();

    // Zero groups is a valid constraint outside of rewrites. In a rewrite, a
    // single variable-length group can simply be left as an empty list.
    if (parserContext != ParserContext::Rewrite || odsValues.size() == 1)
      return success();

    // Otherwise, materialize one empty range per ODS group so that the
    // created operation adheres to its ODS signature.
    for (size_t groupIdx = 0, numGroups = odsValues.size();
         groupIdx != numGroups; ++groupIdx) {
      values.push_back(ast::RangeExpr::create(
          ctx, loc, /*elements=*/std::nullopt, rangeTy));
    }
    return success();
  }

  // The provided groups must line up one-to-one with the ODS groups.
  if (values.size() != odsValues.size())
    return emitSizeMismatchError();

  auto attachDefinitionNote = [&](ast::Diagnostic &diag) {
    diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
                    *odsOpLoc);
  };
  for (size_t groupIdx = 0, numGroups = values.size(); groupIdx != numGroups;
       ++groupIdx) {
    // Variadic ODS groups expect a range; all others a single element.
    ast::Type expectedTy =
        odsValues[groupIdx].isVariadic() ? rangeTy : singleTy;
    if (failed(convertExpressionTo(values[groupIdx], expectedTy,
                                   attachDefinitionNote)))
      return failure();
  }
  return success();
}
/// Create a tuple expression from `elements`, named by `elementNames`.
/// Elements of constraint, rewrite, or tuple type cannot be placed in a tuple.
/// An error is emitted at the first such element.
FailureOr<ast::TupleExpr *>
Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
                        ArrayRef<StringRef> elementNames) {
  auto invalidIt = llvm::find_if(elements, [](const ast::Expr *element) {
    ast::Type elementTy = element->getType();
    return elementTy.isa<ast::ConstraintType, ast::RewriteType,
                         ast::TupleType>();
  });
  if (invalidIt == elements.end())
    return ast::TupleExpr::create(ctx, loc, elements, elementNames);

  const ast::Expr *invalidElement = *invalidIt;
  ast::Type eleTy = invalidElement->getType();
  return emitError(
      invalidElement->getLoc(),
      llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
}
//===----------------------------------------------------------------------===//
// Stmts
/// Create an `erase` statement for `rootOp`, which must be an `Op` expression.
/// If it is not, an error naming the offending type is emitted at the root's
/// location and failure is returned.
FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
                                                    ast::Expr *rootOp) {
  // Check that root is an Operation. Report the actual type, matching the
  // diagnostics emitted by `createReplaceStmt` and `createRewriteStmt`.
  ast::Type rootType = rootOp->getType();
  if (!rootType.isa<ast::OperationType>()) {
    return emitError(
        rootOp->getLoc(),
        llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
  }
  return ast::EraseStmt::create(ctx, loc, rootOp);
}
/// Create a `replace` statement that replaces `rootOp` with `replValues`.
/// `rootOp` must be an `Op` expression. Every replacement must be an `Op`,
/// `Value`, or `ValueRange` expression. When more than one replacement is
/// given, `Op` replacements are rewritten in place into their value form.
FailureOr<ast::ReplaceStmt *>
Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
                          MutableArrayRef<ast::Expr *> replValues) {
  ast::Type rootType = rootOp->getType();
  if (!rootType.isa<ast::OperationType>()) {
    return emitError(
        rootOp->getLoc(),
        llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
  }

  const bool convertOpsToValues = replValues.size() > 1;
  for (size_t idx = 0, numRepls = replValues.size(); idx != numRepls; ++idx) {
    ast::Expr *&replExpr = replValues[idx];
    ast::Type replType = replExpr->getType();
    if (replType.isa<ast::OperationType>()) {
      if (convertOpsToValues)
        replExpr = convertOpToValue(replExpr);
    } else if (replType != valueTy && replType != valueRangeTy) {
      return emitError(replExpr->getLoc(),
                       llvm::formatv("expected `Op`, `Value` or `ValueRange` "
                                     "expression, but got `{0}`",
                                     replType));
    }
  }
  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
}
/// Create a `rewrite` statement that applies `rewriteBody` to `rootOp`.
/// `rootOp` must be an `Op` expression; otherwise an error is emitted.
FailureOr<ast::RewriteStmt *>
Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
                          ast::CompoundStmt *rewriteBody) {
  ast::Type rootType = rootOp->getType();
  if (rootType.isa<ast::OperationType>())
    return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);

  return emitError(
      rootOp->getLoc(),
      llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
}
//===----------------------------------------------------------------------===//
// Code Completion
//===----------------------------------------------------------------------===//
/// Request member-access completions for `parentExpr` based on its type.
/// `Op` types get operation member completions and tuple types get tuple
/// member completions. Any other type produces no request.
/// Always returns failure (presumably so the caller stops parsing at the
/// completion point; confirm against the header).
LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
  ast::Type parentType = parentExpr->getType();
  if (auto opType = parentType.dyn_cast<ast::OperationType>()) {
    codeCompleteContext->codeCompleteOperationMemberAccess(opType);
    return failure();
  }
  if (auto tupleType = parentType.dyn_cast<ast::TupleType>())
    codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
  return failure();
}
/// Request attribute-name completions for the operation named `opName`.
/// Nothing is requested when no operation name is known.
/// Always returns failure.
LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
  if (!opName)
    return failure();
  codeCompleteContext->codeCompleteOperationAttributeName(*opName);
  return failure();
}
/// Request constraint-name completions, forwarding `inferredType`,
/// `allowInlineTypeConstraints`, and the current declaration scope to the
/// completion context. Always returns failure.
LogicalResult
Parser::codeCompleteConstraintName(ast::Type inferredType,
                                   bool allowInlineTypeConstraints) {
  CodeCompleteContext &completer = *codeCompleteContext;
  completer.codeCompleteConstraintName(inferredType, allowInlineTypeConstraints,
                                       curDeclScope);
  return failure();
}
/// Request dialect-name completions from the completion context.
/// Always returns failure.
LogicalResult Parser::codeCompleteDialectName() {
  CodeCompleteContext &completer = *codeCompleteContext;
  completer.codeCompleteDialectName();
  return failure();
}
/// Request operation-name completions within the dialect `dialectName`.
/// Always returns failure.
LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
  CodeCompleteContext &completer = *codeCompleteContext;
  completer.codeCompleteOperationName(dialectName);
  return failure();
}
/// Request pattern-metadata completions from the completion context.
/// Always returns failure.
LogicalResult Parser::codeCompletePatternMetadata() {
  CodeCompleteContext &completer = *codeCompleteContext;
  completer.codeCompletePatternMetadata();
  return failure();
}
/// Request include-filename completions for the partially written path
/// `curPath`. Always returns failure.
LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
  CodeCompleteContext &completer = *codeCompleteContext;
  completer.codeCompleteIncludeFilename(curPath);
  return failure();
}
/// Request call-signature help for the callable declaration extracted from
/// `parent`, forwarding `currentNumArgs`. Does nothing when no callable
/// declaration can be extracted.
void Parser::codeCompleteCallSignature(ast::Node *parent,
                                       unsigned currentNumArgs) {
  if (ast::CallableDecl *callable = tryExtractCallableDecl(parent))
    codeCompleteContext->codeCompleteCallSignature(callable, currentNumArgs);
}
/// Request operand-signature help for the (possibly unknown) operation
/// `opName`, forwarding the number of operands written so far.
void Parser::codeCompleteOperationOperandsSignature(
    Optional<StringRef> opName, unsigned currentNumOperands) {
  CodeCompleteContext &completer = *codeCompleteContext;
  completer.codeCompleteOperationOperandsSignature(opName, currentNumOperands);
}
/// Request result-signature help for the (possibly unknown) operation
/// `opName`, forwarding the number of results written so far.
void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName,
                                                   unsigned currentNumResults) {
  CodeCompleteContext &completer = *codeCompleteContext;
  completer.codeCompleteOperationResultsSignature(opName, currentNumResults);
}
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
/// Parse the PDLL source held by `sourceMgr` into an `ast::Module` using a
/// freshly constructed `Parser`. `enableDocumentation` and the optional
/// `codeCompleteContext` are forwarded to the parser unchanged.
FailureOr<ast::Module *>
mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
                         bool enableDocumentation,
                         CodeCompleteContext *codeCompleteContext) {
  return Parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext)
      .parseModule();
}