derived table and row number addition

This commit is contained in:
2025-10-21 00:27:00 +05:30
parent 12d6fface6
commit 9837fb1e37
6 changed files with 112 additions and 7 deletions

View File

@@ -1,3 +1,5 @@
.PHONY: run bench-select test
run: run:
go run ./cmd -o ./playground/db ./playground/schema.sql go run ./cmd -o ./playground/db ./playground/schema.sql
bench-select: bench-select:

View File

@@ -84,6 +84,43 @@ func (f Field) Desc() Field {
return Field(f.String() + " DESC") return Field(f.String() + " DESC")
} }
func (f Field) RowNumber(as string) Field {
return rowNumber(&f, nil, true, as)
}
func (f Field) RowNumberDesc(as string) Field {
return rowNumber(&f, nil, true, as)
}
// RowNumberPartionBy in ascending order
func (f Field) RowNumberPartionBy(partition Field, as string) Field {
return rowNumber(&f, &partition, true, as)
}
func (f Field) RowNumberDescPartionBy(partition Field, as string) Field {
return rowNumber(&f, &partition, false, as)
}
func rowNumber(f, partition *Field, isAsc bool, as string) Field {
var orderBy string
if isAsc {
orderBy = " ASC"
} else {
orderBy = " DESC"
}
if as == "" {
as = "row_number"
}
col := f.String()
if partition != nil {
return Field("ROW_NUMBER() OVER (PARTITION BY " + partition.String() + " ORDER BY " + col + orderBy + ") AS " + as)
}
return Field("ROW_NUMBER() OVER (ORDER BY " + col + orderBy + ") AS " + as)
}
func (f Field) IsNull() Conditioner { func (f Field) IsNull() Conditioner {
col := f.String() col := f.String()
return &Cond{Field: col, op: " IS NULL", len: len(col) + 8} return &Cond{Field: col, op: " IS NULL", len: len(col) + 8}

View File

@@ -2,10 +2,11 @@ package pgm
// Table in database // Table in database
type Table struct { type Table struct {
Name string Name string
PK []string DerivedTable Query
FieldCount uint16 PK []string
debug bool FieldCount uint16
debug bool
} }
// Debug when set true will print generated query string in stdout // Debug when set true will print generated query string in stdout
@@ -14,6 +15,10 @@ func (t *Table) Debug() Clause {
return t return t
} }
func (t *Table) Field(f string) Field {
return Field(t.Name + "." + f)
}
// Insert table statement // Insert table statement
func (t *Table) Insert() InsertClause { func (t *Table) Insert() InsertClause {
qb := &insertQry{ qb := &insertQry{
@@ -30,10 +35,17 @@ func (t *Table) Insert() InsertClause {
func (t *Table) Select(field ...Field) SelectClause { func (t *Table) Select(field ...Field) SelectClause {
qb := &selectQry{ qb := &selectQry{
debug: t.debug, debug: t.debug,
table: t.Name,
fields: field, fields: field,
} }
if t.DerivedTable != nil {
tName, args := t.DerivedTable.Build(true)
qb.table = "(" + tName + ") AS " + t.Name
qb.args = args
} else {
qb.table = t.Name
}
return qb return qb
} }

11
playground/db/derived.go Normal file
View File

@@ -0,0 +1,11 @@
package db
import "code.patial.tech/go/pgm"
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
t := pgm.Table{
Name: tblName,
DerivedTable: fromQry,
}
return t
}

View File

@@ -60,6 +60,27 @@ func TestSelectWithHaving(t *testing.T) {
} }
} }
func TestSelectDerived(t *testing.T) {
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
" ORDER BY t.status_id, t.created_at DESC"
qry := db.User.
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
Where(user.StatusID.Eq(1))
tbl := db.DerivedTable("t", qry)
got := tbl.
Select(tbl.Field("*")).
Where(tbl.Field("rn").Lte(5)).
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
String()
if expected != got {
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
}
}
func TestSelectWithJoin(t *testing.T) { func TestSelectWithJoin(t *testing.T) {
got := db.User.Select(user.Email, user.FirstName). got := db.User.Select(user.Email, user.FirstName).
Join(db.UserSession, user.ID, usersession.UserID). Join(db.UserSession, user.ID, usersession.UserID).

View File

@@ -7,6 +7,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"slices"
"strconv" "strconv"
"strings" "strings"
@@ -103,6 +104,7 @@ type (
First First
All All
Stringer Stringer
Bulder
} }
RowScanner interface { RowScanner interface {
@@ -128,6 +130,10 @@ type (
AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error
} }
Bulder interface {
Build(needArgs bool) (qry string, args []any)
}
selectQry struct { selectQry struct {
table string table string
fields []Field fields []Field
@@ -291,6 +297,17 @@ func (q *selectQry) raw(prefixArgs []any) (string, []any) {
} }
func (q *selectQry) String() string { func (q *selectQry) String() string {
qry, _ := q.Build(false)
if q.debug {
fmt.Println("***")
fmt.Println(qry)
fmt.Printf("%+v\n", q.args)
fmt.Println("***")
}
return qry
}
func (q *selectQry) Build(needArgs bool) (qry string, args []any) {
sb := getSB() sb := getSB()
defer putSB(sb) defer putSB(sb)
@@ -356,14 +373,19 @@ func (q *selectQry) String() string {
sb.WriteString(strconv.Itoa(q.offset)) sb.WriteString(strconv.Itoa(q.offset))
} }
qry := sb.String() qry = sb.String()
if q.debug { if q.debug {
fmt.Println("***") fmt.Println("***")
fmt.Println(qry) fmt.Println(qry)
fmt.Printf("%+v\n", q.args) fmt.Printf("%+v\n", q.args)
fmt.Println("***") fmt.Println("***")
} }
return qry
if needArgs {
args = slices.Clone(q.args)
}
return
} }
func (q *selectQry) averageLen() int { func (q *selectQry) averageLen() int {