From 37f50324fdc0023d3bd958917f994f6199a2c89a Mon Sep 17 00:00:00 2001 From: Jan Tytgat Date: Wed, 21 May 2025 21:04:53 +0200 Subject: [PATCH] Move go-sql-queryrepo package into go-kit/sqr Signed-off-by: Jan Tytgat --- go.mod | 2 +- go.sum | 2 + sqr/collection.go | 36 ++++++++ sqr/collection_test.go | 162 ++++++++++++++++++++++++++++++++++ sqr/loader.go | 30 +++++++ sqr/preparer.go | 76 ++++++++++++++++ sqr/repository.go | 191 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 498 insertions(+), 1 deletion(-) create mode 100644 sqr/collection.go create mode 100644 sqr/collection_test.go create mode 100644 sqr/loader.go create mode 100644 sqr/preparer.go create mode 100644 sqr/repository.go diff --git a/go.mod b/go.mod index 861c3a5..37e9303 100644 --- a/go.mod +++ b/go.mod @@ -12,5 +12,5 @@ require ( require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/samber/lo v1.50.0 // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/text v0.25.0 // indirect ) diff --git a/go.sum b/go.sum index a2fc126..c68088e 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sqr/collection.go b/sqr/collection.go new file mode 100644 index 0000000..5109c24 --- /dev/null +++ b/sqr/collection.go @@ -0,0 +1,36 @@ +package queryrepo + +import ( + "fmt" +) + +// newCollection creates a new collection with the supplied name and returns it to the caller. +func newCollection(name string) collection { + return collection{ + name: name, + queries: make(map[string]string), + } +} + +type collection struct { + name string + queries map[string]string +} + +// add adds a query to the collection. +func (c *collection) add(name, query string) error { + if _, ok := c.queries[name]; ok { + return fmt.Errorf("query %s already exists", name) + } + c.queries[name] = query + return nil +} + +// get retrieves a query from the collection by name. +// If the query name cannot be found, get() returns an empty string and an error. +func (c *collection) get(name string) (string, error) { + if _, ok := c.queries[name]; !ok { + return "", fmt.Errorf("query %s not found in collection %s", name, c.name) + } + return c.queries[name], nil +} diff --git a/sqr/collection_test.go b/sqr/collection_test.go new file mode 100644 index 0000000..4a85bae --- /dev/null +++ b/sqr/collection_test.go @@ -0,0 +1,162 @@ +package queryrepo + +import ( + "reflect" + "testing" +) + +func Test_collection_add(t *testing.T) { + type fields struct { + name string + queries map[string]string + } + + type args struct { + name string + query string + } + + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "good", + fields: fields{ + name: "test1", + queries: map[string]string{ + "query1": "queryString1", + "query2": "queryString2", + }, + }, + args: args{ + name: "query3", + query: "queryString3", + }, + wantErr: false, + }, + { + name: "bad", + fields: fields{ + name: "test1", + queries: map[string]string{ + "query1": "queryString1", + "query2": "queryString2", + }, + }, + args: args{ + name: "query2", + query: "queryString2", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &collection{ + name: tt.fields.name, + queries: tt.fields.queries, + } + if err := c.add(tt.args.name, tt.args.query); (err != nil) != tt.wantErr { + t.Errorf("add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_collection_get(t *testing.T) { + type fields struct { + name string + queries map[string]string + } + type args struct { + name string + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + { + name: "good", + fields: fields{ + name: "test1", + queries: map[string]string{ + "query1": "queryString1", + "query2": "queryString2", + }, + }, + args: args{ + name: "query1", + }, + want: "queryString1", + wantErr: false, + }, + { + name: "bad", + fields: fields{ + name: "test1", + queries: map[string]string{ + "query1": "queryString1", + "query2": "queryString2", + }, + }, + args: args{ + name: "query3", + }, + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &collection{ + name: tt.fields.name, + queries: tt.fields.queries, + } + got, err := c.get(tt.args.name) + if (err != nil) != tt.wantErr { + t.Errorf("get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("get() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_newCollection(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + want collection + }{ + { + name: "test1", + args: args{ + name: "test1", + }, + want: collection{ + name: "test1", + queries: make(map[string]string), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newCollection(tt.args.name); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newCollection() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sqr/loader.go b/sqr/loader.go new file mode 100644 index 0000000..10245df --- /dev/null +++ b/sqr/loader.go @@ -0,0 +1,30 @@ +package queryrepo + +import ( + "embed" + "fmt" + "io/fs" + "path" + "path/filepath" +) + +// LoadQueryFromFs retrieves a query from a filesystem. +// It needs the root path to start the search from, as well as a collection name and a query name. +// The collection name equals to a direct directory name in the root path. +// The query name is the file name (without extension) to load the contents from. +// It returns and empty string and an error if the file cannot be found. +func LoadQueryFromFs(f fs.FS, rootPath, collectionName, queryName string) (string, error) { + var err error + var contents []byte + switch f.(type) { + case embed.FS: + if contents, err = fs.ReadFile(f, path.Join(rootPath, collectionName, queryName)+".sql"); err != nil { + return "", fmt.Errorf("failed to read file %s: %w", path.Join(rootPath, collectionName, queryName)+".sql", err) + } + default: + if contents, err = fs.ReadFile(f, filepath.Join(rootPath, collectionName, queryName)+".sql"); err != nil { + return "", fmt.Errorf("failed to read file %s: %w", filepath.Join(rootPath, collectionName, queryName)+".sql", err) + } + } + return string(contents), nil +} diff --git a/sqr/preparer.go b/sqr/preparer.go new file mode 100644 index 0000000..be23835 --- /dev/null +++ b/sqr/preparer.go @@ -0,0 +1,76 @@ +package queryrepo + +import ( + "context" + "database/sql" + "errors" + "io/fs" +) + +// Preparer defines the interface to create a prepared statement. +type Preparer interface { + Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) +} + +// Prepare creates a prepared statement for the supplied Preparer by looking up a query in the supplied repository. +// It returns an nil pointer and an error if either the query cannot be found in the supplied repository, or the statement preparation fails. +func Prepare[T Preparer](t T, r *Repository, collectionName, queryName string) (*sql.Stmt, error) { + if r == nil { + return nil, errors.New("repository is nil") + } + + var err error + var query string + + if query, err = r.Get(collectionName, queryName); err != nil { + return nil, err + } + return t.Prepare(query) +} + +// Prepare creates a prepared statement for the supplied Preparer by looking up a query in the supplied repository using a context. +// It returns an nil pointer and an error if either the query cannot be found in the supplied repository, or the statement preparation fails. +func PrepareContext[T Preparer](ctx context.Context, t T, r *Repository, collectionName, queryName string) (*sql.Stmt, error) { + if r == nil { + return nil, errors.New("repository is nil") + } + + var err error + var query string + + if query, err = r.Get(collectionName, queryName); err != nil { + return nil, err + } + return t.PrepareContext(ctx, query) +} + +// PrepareFromFs creates a prepared statement for the supplied Preparer by looking up a query in the supplied filesystem. +// It returns an nil pointer and an error if either the query cannot be found in the supplied filesystem, or the statement preparation fails. +func PrepareFromFs[T Preparer](t T, f fs.FS, rootPath, collectionName, queryName string) (*sql.Stmt, error) { + if f == nil { + return nil, errors.New("invalid filesystem") + } + var err error + var query string + + if query, err = LoadQueryFromFs(f, rootPath, collectionName, queryName); err != nil { + return nil, err + } + return t.Prepare(query) +} + +// PrepareFromFs creates a prepared statement for the supplied Preparer by looking up a query in the supplied filesystem using a context. +// It returns an nil pointer and an error if either the query cannot be found in the supplied filesystem, or the statement preparation fails. +func PrepareFromFsContext[T Preparer](ctx context.Context, t T, f fs.FS, rootPath, collectionName, queryName string) (*sql.Stmt, error) { + if f == nil { + return nil, errors.New("invalid filesystem") + } + var err error + var query string + + if query, err = LoadQueryFromFs(f, rootPath, collectionName, queryName); err != nil { + return nil, err + } + return t.PrepareContext(ctx, query) +} diff --git a/sqr/repository.go b/sqr/repository.go new file mode 100644 index 0000000..af8c763 --- /dev/null +++ b/sqr/repository.go @@ -0,0 +1,191 @@ +// Package queryrepo enables the use of centralized storage for all SQL queries used in an application. +package queryrepo + +import ( + "context" + "database/sql" + "embed" + "errors" + "fmt" + "io/fs" + "path" + "path/filepath" + "strings" + "sync" +) + +// NewFromFs creates a new repository using a filesystem. +// It takes a filesystem and a root path to start loading files from and returns an error if files cannot be loaded. +func NewFromFs(f fs.FS, rootPath string) (*Repository, error) { + repo := &Repository{ + queries: make(map[string]collection), + } + + return repo, loadFromFs(repo, f, rootPath) +} + +// A Repository stores multiple collections of queries in a map for later use. +// Queries can either be retrieved by their name, or be used to create a prepared statement. +type Repository struct { + queries map[string]collection + mux sync.Mutex +} + +// add adds the supplied collection to the repository. +// It returns an error if the collection already exists. +func (r *Repository) add(c collection) error { + r.mux.Lock() + defer r.mux.Unlock() + + if _, ok := r.queries[c.name]; ok { + return fmt.Errorf("collection %s already exists", c.name) + } + r.queries[c.name] = c + return nil +} + +// DbPrepare creates a prepared statement for the supplied database handle. +// It takes a collection name and query name to look up the query to create the prepared statement. +func (r *Repository) DbPrepare(db *sql.DB, collectionName, queryName string) (*sql.Stmt, error) { + if db == nil { + return nil, errors.New("db is nil") + } + + var err error + var query string + + if query, err = r.Get(collectionName, queryName); err != nil { + return nil, err + } + return db.Prepare(query) +} + +// DbPrepareContext creates a prepared statement for the supplied database handle using a context. +// It takes a collection name and query name to look up the query to create the prepared statement. +func (r *Repository) DbPrepareContext(ctx context.Context, db *sql.DB, collectionName, queryName string) (*sql.Stmt, error) { + if db == nil { + return nil, errors.New("db is nil") + } + + var err error + var query string + + if query, err = r.Get(collectionName, queryName); err != nil { + return nil, err + } + + return db.PrepareContext(ctx, query) +} + +// Get retrieves the supplied query from the repository. +// It takes a collection name and a query name to perform the lookup and returns an empty string and an error if the query cannot be found +// in the collection. +func (r *Repository) Get(collectionName, queryName string) (string, error) { + r.mux.Lock() + defer r.mux.Unlock() + + if s, ok := r.queries[collectionName]; ok { + return s.get(queryName) + } + return "", fmt.Errorf("collection %s not found", collectionName) +} + +// TxPrepare creates a prepared statement for the supplied in-progress database transaction. +// It takes a collection name and query name to look up the query to create the prepared statement. +func (r *Repository) TxPrepare(tx *sql.Tx, collectionName, queryName string) (*sql.Stmt, error) { + if tx == nil { + return nil, errors.New("tx is nil") + } + var err error + var statement string + + if statement, err = r.Get(collectionName, queryName); err != nil { + return nil, err + } + + return tx.Prepare(statement) +} + +// TxPrepare creates a prepared statement for the supplied in-progress database transaction using a context. +// It takes a collection name and query name to look up the query to create the prepared statement. +func (r *Repository) TxPrepareContext(ctx context.Context, tx *sql.Tx, collectionName, queryName string) (*sql.Stmt, error) { + if tx == nil { + return nil, errors.New("tx is nil") + } + var err error + var statement string + + if statement, err = r.Get(collectionName, queryName); err != nil { + return nil, err + } + + return tx.PrepareContext(ctx, statement) +} + +// loadFromFs looks for directories in the root path to create collections for. +// If a directory is found, it loads all the files in the subdirectory and adds the returned collection to the repository. +func loadFromFs(r *Repository, f fs.FS, rootPath string) error { + if r == nil { + return errors.New("repository is nil") + } + + if f == nil { + return errors.New("filesystem is nil") + } + + var err error + var files []fs.DirEntry + if files, err = fs.ReadDir(f, rootPath); err != nil { + return err + } + + for _, file := range files { + if file.IsDir() { + var c collection + if c, err = loadFilesFromDir(f, rootPath, file.Name()); err != nil { + return err + } + + if err = r.add(c); err != nil { + return err + } + } + } + return nil +} + +// loadFilesFromDir loads all the files in the directory and returns a collection of queries. +func loadFilesFromDir(f fs.FS, rootPath, dirName string) (collection, error) { + var err error + var c = newCollection(dirName) + var fullPath string + + switch f.(type) { + case embed.FS: + fullPath = path.Join(rootPath, dirName) + default: + fullPath = filepath.Join(rootPath, dirName) + + } + + var files []fs.DirEntry + if files, err = fs.ReadDir(f, fullPath); err != nil { + return c, err + } + + for _, file := range files { + if file.IsDir() { + return c, fmt.Errorf("nested directories are not supported, %s is a directory in %s", file.Name(), fullPath) + } + + var contents string + if contents, err = LoadQueryFromFs(f, rootPath, dirName, strings.TrimSuffix(file.Name(), filepath.Ext(file.Name()))); err != nil { + return c, err + } + + if err = c.add(strings.TrimSuffix(file.Name(), filepath.Ext(file.Name())), contents); err != nil { + return c, err + } + } + return c, nil +}