Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 551e2123bc | |||
| a2b984c342 | |||
| a795c0e8d6 | |||
| bb6a45732f |
3
Makefile
3
Makefile
@@ -2,8 +2,9 @@
|
||||
|
||||
run:
|
||||
go run ./cmd -o ./playground/db ./playground/schema.sql
|
||||
|
||||
bench-select:
|
||||
go test ./example -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
||||
go test ./playground -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
||||
|
||||
test:
|
||||
go test ./playground
|
||||
|
||||
@@ -76,8 +76,18 @@ func generate(scheamPath, outDir string) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(")")
|
||||
|
||||
sb.WriteString(`
|
||||
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||
t := pgm.Table{
|
||||
Name: tblName,
|
||||
DerivedTable: fromQry,
|
||||
}
|
||||
return t
|
||||
}`)
|
||||
|
||||
// Format code before saving
|
||||
code, err := formatGoCode(sb.String())
|
||||
if err != nil {
|
||||
|
||||
27
pgm.go
27
pgm.go
@@ -10,6 +10,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -106,3 +107,29 @@ func PgTime(t time.Time) pgtype.Timestamptz {
|
||||
func PgTimeNow() pgtype.Timestamptz {
|
||||
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
||||
}
|
||||
|
||||
func TsAndQuery(q string) string {
|
||||
return strings.Join(strings.Fields(q), " & ")
|
||||
}
|
||||
|
||||
func TsPrefixAndQuery(q string) string {
|
||||
return strings.Join(fieldsWithSufix(q, ":*"), " & ")
|
||||
}
|
||||
|
||||
func TsOrQuery(q string) string {
|
||||
return strings.Join(strings.Fields(q), " | ")
|
||||
}
|
||||
|
||||
func TsPrefixOrQuery(q string) string {
|
||||
return strings.Join(fieldsWithSufix(q, ":*"), " | ")
|
||||
}
|
||||
|
||||
func fieldsWithSufix(v, sufix string) []string {
|
||||
fields := strings.Fields(v)
|
||||
prefixed := make([]string, len(fields))
|
||||
for i, f := range fields {
|
||||
prefixed[i] = f + sufix
|
||||
}
|
||||
|
||||
return prefixed
|
||||
}
|
||||
|
||||
29
pgm_field.go
29
pgm_field.go
@@ -18,6 +18,18 @@ func (f Field) Count() Field {
|
||||
return Field("COUNT(" + f.String() + ")")
|
||||
}
|
||||
|
||||
func ConcatWs(sep string, fields ...Field) Field {
|
||||
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
|
||||
}
|
||||
|
||||
func StringAgg(exp, sep string) Field {
|
||||
return Field("string_agg(" + exp + ",'" + sep + "')")
|
||||
}
|
||||
|
||||
func StringAggCast(exp, sep string) Field {
|
||||
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
|
||||
}
|
||||
|
||||
// StringEscape will wrap field with:
|
||||
//
|
||||
// COALESCE(field, ”)
|
||||
@@ -141,6 +153,10 @@ func (f Field) DateTrunc(level, as string) Field {
|
||||
return Field("DATE_TRUNC('" + level + "', " + f.String() + ") AS " + as)
|
||||
}
|
||||
|
||||
func (f Field) TsRank(fieldName, as string) Field {
|
||||
return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as)
|
||||
}
|
||||
|
||||
// EqualFold will use LOWER(column_name) = LOWER(val) for comparision
|
||||
func (f Field) EqFold(val string) Conditioner {
|
||||
col := f.String()
|
||||
@@ -210,16 +226,9 @@ func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
|
||||
return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery}
|
||||
}
|
||||
|
||||
func ConcatWs(sep string, fields ...Field) Field {
|
||||
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
|
||||
}
|
||||
|
||||
func StringAgg(exp, sep string) Field {
|
||||
return Field("string_agg(" + exp + ",'" + sep + "')")
|
||||
}
|
||||
|
||||
func StringAggCast(exp, sep string) Field {
|
||||
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
|
||||
func (f Field) TsQuery(as string) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5}
|
||||
}
|
||||
|
||||
func joinFileds(fields []Field) string {
|
||||
|
||||
15
pgm_table.go
15
pgm_table.go
@@ -2,6 +2,8 @@ package pgm
|
||||
|
||||
// Table in database
|
||||
type Table struct {
|
||||
textSearch *textSearchCTE
|
||||
|
||||
Name string
|
||||
DerivedTable Query
|
||||
PK []string
|
||||
@@ -9,6 +11,13 @@ type Table struct {
|
||||
debug bool
|
||||
}
|
||||
|
||||
// text search Common Table Expression
|
||||
type textSearchCTE struct {
|
||||
name string
|
||||
value string
|
||||
alias string
|
||||
}
|
||||
|
||||
// Debug when set true will print generated query string in stdout
|
||||
func (t *Table) Debug() Clause {
|
||||
t.debug = true
|
||||
@@ -31,11 +40,17 @@ func (t *Table) Insert() InsertClause {
|
||||
return qb
|
||||
}
|
||||
|
||||
func (t *Table) WithTextSearch(name, alias, textToSearch string) *Table {
|
||||
t.textSearch = &textSearchCTE{name: name, value: textToSearch, alias: alias}
|
||||
return t
|
||||
}
|
||||
|
||||
// Select table statement
|
||||
func (t *Table) Select(field ...Field) SelectClause {
|
||||
qb := &selectQry{
|
||||
debug: t.debug,
|
||||
fields: field,
|
||||
textSearch: t.textSearch,
|
||||
}
|
||||
|
||||
if t.DerivedTable != nil {
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package db
|
||||
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||
t := pgm.Table{
|
||||
Name: tblName,
|
||||
DerivedTable: fromQry,
|
||||
}
|
||||
return t
|
||||
}
|
||||
@@ -5,10 +5,18 @@ package db
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
var (
|
||||
User = pgm.Table{Name: "users", FieldCount: 11}
|
||||
User = pgm.Table{Name: "users", FieldCount: 12}
|
||||
UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8}
|
||||
BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5}
|
||||
Post = pgm.Table{Name: "posts", FieldCount: 5}
|
||||
Comment = pgm.Table{Name: "comments", FieldCount: 5}
|
||||
Employee = pgm.Table{Name: "employees", FieldCount: 5}
|
||||
)
|
||||
|
||||
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||
t := pgm.Table{
|
||||
Name: tblName,
|
||||
DerivedTable: fromQry,
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ const (
|
||||
StatusID pgm.Field = "users.status_id"
|
||||
// MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying"
|
||||
MfaKind pgm.Field = "users.mfa_kind"
|
||||
// SearchVector field has db type "tsvector"
|
||||
SearchVector pgm.Field = "users.search_vector"
|
||||
// CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||
CreatedAt pgm.Field = "users.created_at"
|
||||
// UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||
|
||||
@@ -17,7 +17,7 @@ func TestInsertQuery(t *testing.T) {
|
||||
Returning(user.ID).
|
||||
String()
|
||||
|
||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id"
|
||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING users.id"
|
||||
if got != expected {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
|
||||
@@ -60,27 +60,6 @@ func TestSelectWithHaving(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectDerived(t *testing.T) {
|
||||
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
|
||||
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
|
||||
" ORDER BY t.status_id, t.created_at DESC"
|
||||
|
||||
qry := db.User.
|
||||
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
|
||||
Where(user.StatusID.Eq(1))
|
||||
|
||||
tbl := db.DerivedTable("t", qry)
|
||||
got := tbl.
|
||||
Select(tbl.Field("*")).
|
||||
Where(tbl.Field("rn").Lte(5)).
|
||||
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
|
||||
String()
|
||||
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithJoin(t *testing.T) {
|
||||
got := db.User.Select(user.Email, user.FirstName).
|
||||
Join(db.UserSession, user.ID, usersession.UserID).
|
||||
@@ -106,6 +85,50 @@ func TestSelectWithJoin(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectDerived(t *testing.T) {
|
||||
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
|
||||
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
|
||||
" ORDER BY t.status_id, t.created_at DESC"
|
||||
|
||||
qry := db.User.
|
||||
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
|
||||
Where(user.StatusID.Eq(1))
|
||||
|
||||
tbl := db.DerivedTable("t", qry)
|
||||
got := tbl.
|
||||
Select(tbl.Field("*")).
|
||||
Where(tbl.Field("rn").Lte(5)).
|
||||
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
|
||||
String()
|
||||
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectTV(t *testing.T) {
|
||||
expected := "WITH ts AS (SELECT to_tsquery('english', $1) AS query)" +
|
||||
" SELECT users.first_name, users.last_name, users.email, TS_RANK(users.search_vector, ts.query) AS rank" +
|
||||
" FROM users" +
|
||||
" JOIN user_sessions ON users.id = user_sessions.user_id" +
|
||||
" CROSS JOIN ts" +
|
||||
" WHERE users.status_id = $2 AND users.search_vector @@ ts.query" +
|
||||
" ORDER BY rank DESC"
|
||||
|
||||
qry := db.User.
|
||||
WithTextSearch("ts", "query", "text to search").
|
||||
Select(user.FirstName, user.LastName, user.Email, user.SearchVector.TsRank("ts.query", "rank")).
|
||||
Join(db.UserSession, user.ID, usersession.UserID).
|
||||
Where(user.StatusID.Eq(1), user.SearchVector.TsQuery("ts.query")).
|
||||
OrderBy(pgm.Field("rank").Desc())
|
||||
|
||||
got := qry.String()
|
||||
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
|
||||
func BenchmarkSelect(b *testing.B) {
|
||||
for b.Loop() {
|
||||
|
||||
@@ -61,6 +61,7 @@ CREATE TABLE public.users (
|
||||
last_name character varying(50) NOT NULL,
|
||||
status_id smallint,
|
||||
mfa_kind character varying(50) DEFAULT 'None'::character varying,
|
||||
search_vector tsvector,
|
||||
created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
@@ -55,7 +55,7 @@ func (q *insertQry) SetMap(cols map[Field]any) InsertClause {
|
||||
}
|
||||
|
||||
func (q *insertQry) Returning(field Field) First {
|
||||
col := field.Name()
|
||||
col := field.String()
|
||||
q.returing = &col
|
||||
return q
|
||||
}
|
||||
|
||||
@@ -135,6 +135,8 @@ type (
|
||||
}
|
||||
|
||||
selectQry struct {
|
||||
textSearch *textSearchCTE
|
||||
|
||||
table string
|
||||
fields []Field
|
||||
args []any
|
||||
@@ -313,6 +315,12 @@ func (q *selectQry) Build(needArgs bool) (qry string, args []any) {
|
||||
|
||||
sb.Grow(q.averageLen())
|
||||
|
||||
if q.textSearch != nil {
|
||||
var ts = q.textSearch
|
||||
q.args = slices.Insert(q.args, 0, any(ts.value))
|
||||
sb.WriteString("WITH " + ts.name + " AS (SELECT to_tsquery('english', $1) AS " + ts.alias + ") ")
|
||||
}
|
||||
|
||||
// SELECT
|
||||
sb.WriteString("SELECT ")
|
||||
sb.WriteString(joinFileds(q.fields))
|
||||
@@ -323,6 +331,11 @@ func (q *selectQry) Build(needArgs bool) (qry string, args []any) {
|
||||
sb.WriteString(" " + strings.Join(q.join, " "))
|
||||
}
|
||||
|
||||
// Search Query Cross join
|
||||
if q.textSearch != nil {
|
||||
sb.WriteString(" CROSS JOIN " + q.textSearch.name)
|
||||
}
|
||||
|
||||
// WHERE
|
||||
if len(q.where) > 0 {
|
||||
sb.WriteString(" WHERE ")
|
||||
|
||||
Reference in New Issue
Block a user