MLIR入门教程7-添加struct类型

2022-04-28 17:49:29 浏览数 (1)

本文翻译自MLIR社区官方文档,受限于笔者个人的认知水平,翻译效果可能不是很理想,翻译原始文档也会放在github上,供大家参考,如发现问题也欢迎提PR或者Issue:

https://github.com/hunterzju/llvm-tutorial

第7章:向Toy添加复合类型

在上一章中,我们演示了从Toy前端到LLVM IR的端到端编译流程。在本章中,我们将扩展Toy语言以支持新的复合struct类型。

在Toy中定义struct

我们需要定义的第一件事是用我们的“Toy”源语言定义这种类型的接口。Toy中struct类型的通用语法如下:

代码语言:javascript复制
# A struct is defined by using the `struct` keyword followed by a name.
struct MyStruct {
  # Inside of the struct is a list of variable declarations without initializers
  # or shapes, which may also be other previously defined structs.
  var a;
  var b;
}

现在,通过使用结构的名称而不是var,可以在函数中将结构用作变量或参数。该结构的成员通过.访问运算符进行访问。struct类型的值可以用复合初始值设定项初始化,也可以用{}括起来的其他初始值设定项的逗号分隔列表进行初始化。示例如下所示:

代码语言:javascript复制
struct Struct {
  var a;
  var b;
}

# User defined generic function may operate on struct types as well.
def multiply_transpose(Struct value) {
  # We can access the elements of a struct via the '.' operator.
  return transpose(value.a) * transpose(value.b);
}

def main() {
  # We initialize struct values using a composite initializer.
  Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};

  # We pass these arguments to functions like we do with variables.
  var c = multiply_transpose(value);
  print(c);
}

在MLIR中定义struct

在MLIR中,我们还需要结构类型的表示形式。MLIR不能提供完全符合我们需要的类型,因此我们需要定义自己的类型。我们将简单地将我们的struct定义为一组元素类型的未命名容器。struct的名称及其元素只对我们的toy编译器的AST有用,所以我们不需要在MLIR表示中对其进行编码。

定义类型类

定义类型类

如第2章中所述,MLIR中的Type对象是值类型的,并且依赖于拥有保存该类型的实际数据的内部存储对象。Type类本身充当内部TypeStorage对象的简单包装,该对象在MLIRContext的实例中是唯一的。在构造Type时,我们在内部只是构造并唯一化一个存储类的实例。

在定义包含参数数据的新Type时(例如struct类型,需要额外的信息来保存元素类型),我们需要提供派生的存储类。没有额外数据的singleton类型(如ef="https://zhuanlan.zhihu.com/LangRef.md#index-type">indextype)不需要存储类,使用默认的TypeStorage。

定义存储类

类型存储对象包含构造和唯一类型实例所需的所有数据。派生存储类必须继承自基本mlir::TypeStorage,并提供一组别名和钩子,供MLIRContext用于唯一类型。下面是我们的struct类型的存储实例的定义,每个必需的要求都内联了详细说明:

代码语言:javascript复制
/// This class represents the internal storage of the Toy `StructType`.
struct StructTypeStorage : public mlir::TypeStorage {
  /// The `KeyTy` is a required type that provides an interface for the storage
  /// instance. This type will be used when uniquing an instance of the type
  /// storage. For our struct type, we will unique each instance structurally on
  /// the elements that it contains.
  using KeyTy = llvm::ArrayRef<mlir::Type>;

  /// A constructor for the type storage instance.
  StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
      : elementTypes(elementTypes) {}

  /// Define the comparison function for the key type with the current storage
  /// instance. This is used when constructing a new instance to ensure that we
  /// haven't already uniqued an instance of the given key.
  bool operator==(const KeyTy &key) const { return key == elementTypes; }

  /// Define a hash function for the key type. This is used when uniquing
  /// instances of the storage.
  /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
  /// have hash functions available, so we could just omit this entirely.
  static llvm::hash_code hashKey(const KeyTy &key) {
    return llvm::hash_value(key);
  }

  /// Define a construction function for the key type from a set of parameters.
  /// These parameters will be provided when constructing the storage instance
  /// itself, see the `StructType::get` method further below.
  /// Note: This method isn't necessary because KeyTy can be directly
  /// constructed with the given parameters.
  static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
    return KeyTy(elementTypes);
  }

  /// Define a construction method for creating a new instance of this storage.
  /// This method takes an instance of a storage allocator, and an instance of a
  /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
  /// allocations used to create the type storage and its internal.
  static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
    // Copy the elements from the provided `KeyTy` into the allocator.
    llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);

    // Allocate the storage instance and construct it.
    return new (allocator.allocate<StructTypeStorage>())
        StructTypeStorage(elementTypes);
  }

  /// The following field contains the element types of the struct.
  llvm::ArrayRef<mlir::Type> elementTypes;
};

定义类型类

定义存储类后,我们可以为用户可见的StructType类添加定义。这是我们将实际与之交互的类。

代码语言:javascript复制
/// This class defines the Toy struct type. It represents a collection of
/// element types. All derived types in MLIR must inherit from the CRTP class
/// 'Type::TypeBase'. It takes as template parameters the concrete type
/// (StructType), the base class to use (Type), and the storage class
/// (StructTypeStorage).
class StructType : public mlir::Type::TypeBase<StructType, mlir::Type,
                                               StructTypeStorage> {
public:
  /// Inherit some necessary constructors from 'TypeBase'.
  using Base::Base;

  /// Create an instance of a `StructType` with the given element types. There
  /// *must* be at least one element type.
  static StructType get(llvm::ArrayRef<mlir::Type> elementTypes) {
    assert(!elementTypes.empty() && "expected at least 1 element type");

    // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
    // of this type. The first parameter is the context to unique in. The
    // parameters after are forwarded to the storage instance.
    mlir::MLIRContext *ctx = elementTypes.front().getContext();
    return Base::get(ctx, elementTypes);
  }

  /// Returns the element types of this struct type.
  llvm::ArrayRef<mlir::Type> getElementTypes() {
    // 'getImpl' returns a pointer to the internal storage instance.
    return getImpl()->elementTypes;
  }

  /// Returns the number of element type held by this struct.
  size_t getNumElementTypes() { return getElementTypes().size(); }
};

我们在Toy Dialect构造函数中注册此类型的方式与我们处理操作的方式类似:

代码语言:javascript复制
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  addTypes<StructType>();
}

有了这个,我们现在可以在从Toy生成MLIR时使用我们的StructType。有关更多详细信息,请参见Examples/toy/ch7/mlir/MLIRGen.cpp

解析和打印

此时,我们可以在MLIR生成和转换过程中使用我们的StructType,但不能输出或解析.mlir。为此,我们需要增加对StructType实例的解析和打印支持。这可以通过覆盖Toy Dialect上的parseTypeprintType方法来实现。

代码语言:javascript复制
class ToyDialect : public mlir::Dialect {
public:
  /// Parse an instance of a type registered to the toy dialect.
  mlir::Type parseType(mlir::DialectAsmParser &parser) const override;

  /// Print an instance of a type registered to the toy dialect.
  void printType(mlir::Type type,
                 mlir::DialectAsmPrinter &printer) const override;
};

这些方法采用允许轻松实现必要功能的高级解析器或打印类的实例。在开始实现之前,让我们先考虑一下打印的IR中的struct类型所需的语法。如MLIR语言参考中所述,方言类型通常表示为:!dialect-namespace<type-data>,在某些情况下可以使用漂亮的形式。我们的toy解析器和打印类的职责是提供type-data位。我们将我们的StructType定义为具有以下形式:

代码语言:javascript复制
struct-type ::= `struct` `<` type (`,` type)* `>`

解析

解析器的实现如下所示:

代码语言:javascript复制
/// Parse an instance of a type registered to the toy dialect.
mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
  // Parse a struct type in the following form:
  //   struct-type ::= `struct` `<` type (`,` type)* `>`

  // NOTE: All MLIR parser function return a ParseResult. This is a
  // specialization of LogicalResult that auto-converts to a `true` boolean
  // value on failure to allow for chaining, but may be used with explicit
  // `mlir::failed/mlir::succeeded` as desired.

  // Parse: `struct` `<`
  if (parser.parseKeyword("struct") || parser.parseLess())
    return Type();

  // Parse the element types of the struct.
  SmallVector<mlir::Type, 1> elementTypes;
  do {
    // Parse the current element type.
    llvm::SMLoc typeLoc = parser.getCurrentLocation();
    mlir::Type elementType;
    if (parser.parseType(elementType))
      return nullptr;

    // Check that the type is either a TensorType or another StructType.
    if (!elementType.isa<mlir::TensorType, StructType>()) {
      parser.emitError(typeLoc, "element type for a struct must either "
                                "be a TensorType or a StructType, got: ")
          << elementType;
      return Type();
    }
    elementTypes.push_back(elementType);

    // Parse the optional: `,`
  } while (succeeded(parser.parseOptionalComma()));

  // Parse: `>`
  if (parser.parseGreater())
    return Type();
  return StructType::get(elementTypes);
}

打印

打印类的实现如下所示:

代码语言:javascript复制
/// Print an instance of a type registered to the toy dialect.
void ToyDialect::printType(mlir::Type type,
                           mlir::DialectAsmPrinter &printer) const {
  // Currently the only toy type is a struct type.
  StructType structType = type.cast<StructType>();

  // Print the struct type according to the parser format.
  printer << "struct<";
  llvm::interleaveComma(structType.getElementTypes(), printer);
  printer << '>';
}

在继续之前,让我们先来看一下展示我们现在拥有的功能的快速示例:

代码语言:javascript复制
struct Struct {
  var a;
  var b;
}

def multiply_transpose(Struct value) {
}

它会生成以下内容:

代码语言:javascript复制
module {
  func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
    toy.return
  }
}

StructType上操作

现在已经定义了struct类型,我们可以往返于IR之间。下一步是添加对在我们的操作中使用它的支持。

更新现有操作

我们现有的一些操作需要更新以处理StructType。第一步是让ODS框架知道我们的Type,这样我们就可以在操作定义中使用它。下面是一个简单的示例:

代码语言:javascript复制
// Provide a definition for the Toy StructType for use in ODS. This allows for
// using StructType in a similar way to Tensor or MemRef.
def Toy_StructType :
    Type<CPred<"$_self.isa<StructType>()">, "Toy struct type">;

// Provide a definition of the types that are used within the Toy dialect.
def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;

然后我们可以更新我们的操作,例如ReturnOp,以也接受Toy_StructType

代码语言:javascript复制
def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
  ...
  let arguments = (ins Variadic<Toy_Type>:$input);
  ...
}

添加新的TOY操作

除了现有的操作之外,我们还将添加一些新的操作,这些操作将提供对structs的更具体的处理。

toy.struct_constant

这个新操作实现了结构的常量值。在我们当前的建模中,我们只使用了一个数组属性,它为每个struct元素包含一组常量值。

代码语言:javascript复制
%0 = toy.struct_constant [
    dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>
  ] : !toy.struct<tensor<*xf64>>

toy.struct_access

这个新操作实现了struct值的第N个元素。

代码语言:javascript复制
// Using %0 from above
  %1 = toy.struct_access %0[0] : !toy.struct<tensor<*xf64>> -> tensor<*xf64>

通过这些操作,我们可以重新查看最初的示例:

代码语言:javascript复制
struct Struct {
  var a;
  var b;
}

# User defined generic function may operate on struct types as well.
def multiply_transpose(Struct value) {
  # We can access the elements of a struct via the '.' operator.
  return transpose(value.a) * transpose(value.b);
}

def main() {
  # We initialize struct values using a composite initializer.
  Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};

  # We pass these arguments to functions like we do with variables.
  var c = multiply_transpose(value);
  print(c);
}

并最终获得完整的MLIR模块:

代码语言:javascript复制
module {
  func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> {
    %0 = toy.struct_access %arg0[0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
    %2 = toy.struct_access %arg0[1] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %3 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64>
    %4 = toy.mul %1, %3 : tensor<*xf64>
    toy.return %4 : tensor<*xf64>
  }
  func @main() {
    %0 = toy.struct_constant [
      dense<[[1.000000e 00, 2.000000e 00, 3.000000e 00], [4.000000e 00, 5.000000e 00, 6.000000e 00]]> : tensor<2x3xf64>,
      dense<[[1.000000e 00, 2.000000e 00, 3.000000e 00], [4.000000e 00, 5.000000e 00, 6.000000e 00]]> : tensor<2x3xf64>
    ] : !toy.struct<tensor<*xf64>, tensor<*xf64>>
    %1 = toy.generic_call @multiply_transpose(%0) : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
    toy.print %1 : tensor<*xf64>
    toy.return
  }
}

优化StructType的操作

现在我们有几个操作在“StructType”上,我们也有许多新的常量折叠机会。

内联后,上一节中的MLIR模块如下所示:

代码语言:javascript复制
module {
  func @main() {
    %0 = toy.struct_constant [
      dense<[[1.000000e 00, 2.000000e 00, 3.000000e 00], [4.000000e 00, 5.000000e 00, 6.000000e 00]]> : tensor<2x3xf64>,
      dense<[[1.000000e 00, 2.000000e 00, 3.000000e 00], [4.000000e 00, 5.000000e 00, 6.000000e 00]]> : tensor<2x3xf64>
    ] : !toy.struct<tensor<*xf64>, tensor<*xf64>>
    %1 = toy.struct_access %0[0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %2 = toy.transpose(%1 : tensor<*xf64>) to tensor<*xf64>
    %3 = toy.struct_access %0[1] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %4 = toy.transpose(%3 : tensor<*xf64>) to tensor<*xf64>
    %5 = toy.mul %2, %4 : tensor<*xf64>
    toy.print %5 : tensor<*xf64>
    toy.return
  }
}

我们有几个访问toy.struct_constanttoy.struct_access操作。如第3章(FoldConstantReshape)所述,我们可以通过在操作定义上设置hasFolder位并提供*Op::fold方法的定义来为这些toy操作添加folder操作。

代码语言:javascript复制
/// Fold constants.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }

/// Fold struct constants.
OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
  return value();
}

/// Fold simple struct access operations that access into a constant.
OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
  auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
  if (!structAttr)
    return nullptr;

  size_t elementIndex = index().getZExtValue();
  return structAttr[elementIndex];
}

为了确保MLIR在折叠我们的Toy操作时生成正确的常量操作,即TensorTypeConstantOpStructTypeStructConstant,我们需要提供方言钩子MaterializeConstant的覆盖。这允许通用MLIR操作在必要时为TOY方言创建常量。

代码语言:javascript复制
mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
                                                 mlir::Attribute value,
                                                 mlir::Type type,
                                                 mlir::Location loc) {
  if (type.isa<StructType>())
    return builder.create<StructConstantOp>(loc, type,
                                            value.cast<mlir::ArrayAttr>());
  return builder.create<ConstantOp>(loc, type,
                                    value.cast<mlir::DenseElementsAttr>());
}

有了这一点,我们现在可以生成可以生成到LLVM的代码,而不需要对我们的流程进行任何更改。

代码语言:javascript复制
module {
  func @main() {
    %0 = toy.constant dense<[[1.000000e 00, 2.000000e 00, 3.000000e 00], [4.000000e 00, 5.000000e 00, 6.000000e 00]]> : tensor<2x3xf64>
    %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
    %2 = toy.mul %1, %1 : tensor<3x2xf64>
    toy.print %2 : tensor<3x2xf64>
    toy.return
  }
}

您可以构建toyc-ch7并亲自试用:toyc-ch7 test/examples/Toy/ch7/struct-codegen.toy -emit=mlir。有关定义自定义类型的更多详细信息,请参阅DefiningAttributesAndTypes.

0 人点赞