go generate の使い方をメモ程度に残しておきます。

既存のRepositoryモデルを元に、RepositoryモデルをDomainモデルに変換するマッパーを生成するようなケースを紹介します。

フォルダ構成

├── domains
│   └── model.go                      # Domainモデル
└── repositories
    ├── gen
    │   ├── gen.go                    # generator本体
    │   └── mapper.tmpl               # 参照するテンプレート
    ├── model.go                      # 参照するRepositoryモデル
    └── mapper.gen.go                 # 自動生成されたマッパー

genフォルダ配下にgenerator本体を置くのは、go generate のベストプラクティス を参考にさせていただきました。

参照元のコード

Repositoryモデル

今回はbunを使ってます。

import (
    "github.com/uptrace/bun"
    "time"
)

type DbComment struct {
    bun.BaseModel `bun:"table:comment,alias:c"`

    Id        *int64     `bun:"id,pk"`
    Text      string     `bun:"text"`
    CreatedAt *time.Time `bun:"created_at"`
    UpdatedAt *time.Time `bun:"updated_at"`
}

自動生成

動かしてみる

実際にファイルを生成したいフォルダ配下にgenというフォルダを生成して、そこにgen.goを作成します。

repositories/gen/gen.go
まずは、以下のような内容を記載してみます。

//go:generate go run .
//go:generate gofmt -w ../

import (
    "fmt"
)

func main() {
    fmt.Println("generate!!!")
}

go generateは、実行フォルダ配下で、//go:generate xxxのようなコメントが存在する場合、xxxの部分をコマンドとして実行してくれます。上記の例では

  • 1行目は、go run .が実行され、具体的にはmain()が実行されます
  • 2行目は、gofmt -w ../が実行され、親フォルダ配下(自動生成したソースコードを含む)をフォーマットします。

この時点で、以下のコマンドを実行すると、main()が実行されたことが分かると思います。

cd repositories/gen
go generate
> generate!!!

つまり、あとはmain()の中で、

  • RepositoryモデルのASTを読み込んで必要な情報を取得する
  • Domainモデルとのマッパーファイルを作成する
    • テンプレートを準備しておいてそこに流し込む

という作業をやってあげれば良さそうです。

gen.goを実装する

細々説明するほど難しい実装でもないので、全体をそのまま載っけておきます。
main()から辿っていけば何をやっているかわかると思います。

repositories/gen/gen.go

//go:generate go run .
//go:generate gofmt -w ../

package main

import (
    "bytes"
    "fmt"
    "go/ast"
    "go/parser"
    "go/printer"
    "go/token"
    "log"
    "os"
    "strings"
    "text/template"
)

// 構造体定義
type StructDef struct {
    Name   string
    Fields []FieldDef
}

func (sd StructDef) Alias() string {
    return strings.ToLower(sd.Name[0:1]) + sd.Name[1:]
}

// 構造体フィールド定義
type FieldDef struct {
    Name string
    Type string
}

func (fd FieldDef) Alias() string {
    return strings.ToLower(fd.Name[0:1]) + fd.Name[1:]
}

// ファイルをパースしてASTから必要な構造体情報を返却
func ParseFirstStruct(fpath string) ([]StructDef, error) {
    fset := token.NewFileSet()
    f, err := parser.ParseFile(fset, fpath, nil, 0)
    if err != nil {
        return nil, err
    }

    list := []StructDef{}
    ast.Inspect(f, func(n ast.Node) bool {
        x, ok := n.(*ast.TypeSpec)
        if !ok {
            return true
        }
        if y, ok := x.Type.(*ast.StructType); ok {
            sdef := StructDef{}
            sdef.Name = x.Name.Name
            for _, fld := range y.Fields.List {
                if fld.Names == nil {
                    continue
                }
                var typeNameBuf bytes.Buffer
                err := printer.Fprint(&typeNameBuf, fset, fld.Type)
                if err != nil {
                    log.Fatalf("failed printing %s", err)
                }
                sdef.Fields = append(sdef.Fields, FieldDef{Name: fld.Names[0].Name, Type: typeNameBuf.String()})
            }
            list = append(list, sdef)
        }
        return true
    })
    return list, nil
}

// マッパーファイルを生成
func createMapperFile(outputFilePath string, def StructDef) error {
    file, err := os.Create(outputFilePath)
    if err != nil {
        return err
    }
    defer file.Close()

    t := template.Must(template.ParseFiles("./mapper.tmpl"))
    data := map[string]interface{}{
        "ModelName":  def.Name[2:], // "Comment"
        "ModelArias": strings.ToLower(def.Name[2:][0:1]) + def.Name[2:][1:], // "comment"
        "Fields":     def.Fields,
    }
    if err := t.Execute(file, data); err != nil {
        return err
    }
    fmt.Println(outputFilePath + " is generated.")
    return nil
}

// エントリーポイント
func main() {
    inputFilePath := "../model.go"

    list, err := ParseFirstStruct(inputFilePath)
    if err != nil || len(list) == 0 {
        fmt.Fprintf(os.Stderr, "model parse faild.\n: %s", err)
        os.Exit(1)
    }

    outputFilePath := "../" + strings.Split(strings.Split(inputFilePath, "/")[1], ".")[0] + "_mapper.gen.go"
    err = createMapperFile(outputFilePath, list[0])
    if err != nil {
        fmt.Fprintf(os.Stderr, "code generate failed.\n: %s", err)
        os.Exit(1)
    }
}

/repositories/gen/mapper.tmpl

// Code generated by go generate DO NOT EDIT.

package repositories

import domains "github.com/rinoguchi/microblog/test/domains"

func (db{{ .ModelName }} Db{{ .ModelName }}) ToDomain{{ .ModelName }}() domains.{{ .ModelName }} {
    return domains.{{ .ModelName }}{
    {{range .Fields -}}
    {{ .Name }}:  db{{ $.ModelName }}.{{ .Name }},
    {{end -}}
    }
}

func FromDomain{{ .ModelName }}({{ .ModelArias }} domains.{{ .ModelName }}) Db{{ .ModelName }} {
    return Db{{ .ModelName }}{
    {{range .Fields -}}
    {{ .Name }}:  {{ $.ModelArias }}.{{ .Name }},
    {{end -}}
    }
}

上記を実装した上で、改めてgo generateを実行すると、model_mapper.gen.goが生成されました。

go generate
../model_mapper.gen.go is generated.

/repositories/model_mapper.gen.go

// Code generated by go generate DO NOT EDIT.

package repositories

import domains "github.com/rinoguchi/microblog/test/domains"

func (dbComment DbComment) ToDomainComment() domains.Comment {
    return domains.Comment{
        Id:        dbComment.Id,
        Text:      dbComment.Text,
        CreatedAt: dbComment.CreatedAt,
        UpdatedAt: dbComment.UpdatedAt,
    }
}

func FromDomainComment(comment domains.Comment) DbComment {
    return DbComment{
        Id:        comment.Id,
        Text:      comment.Text,
        CreatedAt: comment.CreatedAt,
        UpdatedAt: comment.UpdatedAt,
    }
}

これにて終了です。

Posted in: go