From 9837fb1e37414d78c7836cd77aba0de8d9afe76b Mon Sep 17 00:00:00 2001 From: Ankit Patial Date: Tue, 21 Oct 2025 00:27:00 +0530 Subject: [PATCH] derived table and row number addition --- Makefile | 2 ++ pgm_field.go | 37 +++++++++++++++++++++++++++++++++++ pgm_table.go | 22 ++++++++++++++++----- playground/db/derived.go | 11 +++++++++++ playground/qry_select_test.go | 21 ++++++++++++++++++++ qry_select.go | 26 ++++++++++++++++++++++-- 6 files changed, 112 insertions(+), 7 deletions(-) create mode 100644 playground/db/derived.go diff --git a/Makefile b/Makefile index aa53152..311bd5e 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +.PHONY: run bench-select test + run: go run ./cmd -o ./playground/db ./playground/schema.sql bench-select: diff --git a/pgm_field.go b/pgm_field.go index a439629..f801f20 100644 --- a/pgm_field.go +++ b/pgm_field.go @@ -84,6 +84,43 @@ func (f Field) Desc() Field { return Field(f.String() + " DESC") } +func (f Field) RowNumber(as string) Field { + return rowNumber(&f, nil, true, as) +} + +func (f Field) RowNumberDesc(as string) Field { + return rowNumber(&f, nil, true, as) +} + +// RowNumberPartionBy in ascending order +func (f Field) RowNumberPartionBy(partition Field, as string) Field { + return rowNumber(&f, &partition, true, as) +} + +func (f Field) RowNumberDescPartionBy(partition Field, as string) Field { + return rowNumber(&f, &partition, false, as) +} + +func rowNumber(f, partition *Field, isAsc bool, as string) Field { + var orderBy string + if isAsc { + orderBy = " ASC" + } else { + orderBy = " DESC" + } + + if as == "" { + as = "row_number" + } + + col := f.String() + if partition != nil { + return Field("ROW_NUMBER() OVER (PARTITION BY " + partition.String() + " ORDER BY " + col + orderBy + ") AS " + as) + } + + return Field("ROW_NUMBER() OVER (ORDER BY " + col + orderBy + ") AS " + as) +} + func (f Field) IsNull() Conditioner { col := f.String() return &Cond{Field: col, op: " IS NULL", len: len(col) + 8} diff --git a/pgm_table.go b/pgm_table.go index 2bf0d9f..2ccf8d0 100644 --- a/pgm_table.go +++ b/pgm_table.go @@ -2,10 +2,11 @@ package pgm // Table in database type Table struct { - Name string - PK []string - FieldCount uint16 - debug bool + Name string + DerivedTable Query + PK []string + FieldCount uint16 + debug bool } // Debug when set true will print generated query string in stdout @@ -14,6 +15,10 @@ func (t *Table) Debug() Clause { return t } +func (t *Table) Field(f string) Field { + return Field(t.Name + "." + f) +} + // Insert table statement func (t *Table) Insert() InsertClause { qb := &insertQry{ @@ -30,10 +35,17 @@ func (t *Table) Insert() InsertClause { func (t *Table) Select(field ...Field) SelectClause { qb := &selectQry{ debug: t.debug, - table: t.Name, fields: field, } + if t.DerivedTable != nil { + tName, args := t.DerivedTable.Build(true) + qb.table = "(" + tName + ") AS " + t.Name + qb.args = args + } else { + qb.table = t.Name + } + return qb } diff --git a/playground/db/derived.go b/playground/db/derived.go new file mode 100644 index 0000000..4fcbf1d --- /dev/null +++ b/playground/db/derived.go @@ -0,0 +1,11 @@ +package db + +import "code.patial.tech/go/pgm" + +func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table { + t := pgm.Table{ + Name: tblName, + DerivedTable: fromQry, + } + return t +} diff --git a/playground/qry_select_test.go b/playground/qry_select_test.go index 64d4a4b..0481d9d 100644 --- a/playground/qry_select_test.go +++ b/playground/qry_select_test.go @@ -60,6 +60,27 @@ func TestSelectWithHaving(t *testing.T) { } } +func TestSelectDerived(t *testing.T) { + expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" + + " FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" + + " ORDER BY t.status_id, t.created_at DESC" + + qry := db.User. + Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")). + Where(user.StatusID.Eq(1)) + + tbl := db.DerivedTable("t", qry) + got := tbl. + Select(tbl.Field("*")). + Where(tbl.Field("rn").Lte(5)). + OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()). + String() + + if expected != got { + t.Errorf("\nexpected: %q\n\ngot: %q", expected, got) + } +} + func TestSelectWithJoin(t *testing.T) { got := db.User.Select(user.Email, user.FirstName). Join(db.UserSession, user.ID, usersession.UserID). diff --git a/qry_select.go b/qry_select.go index a82ca7f..f2b83de 100644 --- a/qry_select.go +++ b/qry_select.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "slices" "strconv" "strings" @@ -103,6 +104,7 @@ type ( First All Stringer + Bulder } RowScanner interface { @@ -128,6 +130,10 @@ type ( AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error } + Bulder interface { + Build(needArgs bool) (qry string, args []any) + } + selectQry struct { table string fields []Field @@ -291,6 +297,17 @@ func (q *selectQry) raw(prefixArgs []any) (string, []any) { } func (q *selectQry) String() string { + qry, _ := q.Build(false) + if q.debug { + fmt.Println("***") + fmt.Println(qry) + fmt.Printf("%+v\n", q.args) + fmt.Println("***") + } + return qry +} + +func (q *selectQry) Build(needArgs bool) (qry string, args []any) { sb := getSB() defer putSB(sb) @@ -356,14 +373,19 @@ func (q *selectQry) String() string { sb.WriteString(strconv.Itoa(q.offset)) } - qry := sb.String() + qry = sb.String() if q.debug { fmt.Println("***") fmt.Println(qry) fmt.Printf("%+v\n", q.args) fmt.Println("***") } - return qry + + if needArgs { + args = slices.Clone(q.args) + } + + return } func (q *selectQry) averageLen() int {