llvm-project/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp

46 lines
1.6 KiB
C++

//===- BuildOnlyExtensionTest.cpp - unit test for transform extensions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::transform;
namespace {
class Extension : public TransformDialectExtension<Extension> {
public:
using Base::Base;
void init() { declareGeneratedDialect<func::FuncDialect>(); }
};
} // end namespace
TEST(BuildOnlyExtensionTest, buildOnlyExtension) {
// Register the build-only version of the transform dialect extension. The
// func dialect is declared as generated so it should not be loaded along with
// the transform dialect.
DialectRegistry registry;
registry.addExtensions<BuildOnly<Extension>>();
MLIRContext ctx(registry);
ctx.getOrLoadDialect<TransformDialect>();
ASSERT_FALSE(ctx.getLoadedDialect<func::FuncDialect>());
}
TEST(BuildOnlyExtensionTest, buildAndApplyExtension) {
// Register the full version of the transform dialect extension. The func
// dialect should be loaded along with the transform dialect.
DialectRegistry registry;
registry.addExtensions<Extension>();
MLIRContext ctx(registry);
ctx.getOrLoadDialect<TransformDialect>();
ASSERT_TRUE(ctx.getLoadedDialect<func::FuncDialect>());
}