derived table and row number addition
This commit is contained in:
2
Makefile
2
Makefile
@@ -1,3 +1,5 @@
|
||||
.PHONY: run bench-select test
|
||||
|
||||
run:
|
||||
go run ./cmd -o ./playground/db ./playground/schema.sql
|
||||
bench-select:
|
||||
|
||||
37
pgm_field.go
37
pgm_field.go
@@ -84,6 +84,43 @@ func (f Field) Desc() Field {
|
||||
return Field(f.String() + " DESC")
|
||||
}
|
||||
|
||||
func (f Field) RowNumber(as string) Field {
|
||||
return rowNumber(&f, nil, true, as)
|
||||
}
|
||||
|
||||
func (f Field) RowNumberDesc(as string) Field {
|
||||
return rowNumber(&f, nil, true, as)
|
||||
}
|
||||
|
||||
// RowNumberPartionBy in ascending order
|
||||
func (f Field) RowNumberPartionBy(partition Field, as string) Field {
|
||||
return rowNumber(&f, &partition, true, as)
|
||||
}
|
||||
|
||||
func (f Field) RowNumberDescPartionBy(partition Field, as string) Field {
|
||||
return rowNumber(&f, &partition, false, as)
|
||||
}
|
||||
|
||||
func rowNumber(f, partition *Field, isAsc bool, as string) Field {
|
||||
var orderBy string
|
||||
if isAsc {
|
||||
orderBy = " ASC"
|
||||
} else {
|
||||
orderBy = " DESC"
|
||||
}
|
||||
|
||||
if as == "" {
|
||||
as = "row_number"
|
||||
}
|
||||
|
||||
col := f.String()
|
||||
if partition != nil {
|
||||
return Field("ROW_NUMBER() OVER (PARTITION BY " + partition.String() + " ORDER BY " + col + orderBy + ") AS " + as)
|
||||
}
|
||||
|
||||
return Field("ROW_NUMBER() OVER (ORDER BY " + col + orderBy + ") AS " + as)
|
||||
}
|
||||
|
||||
func (f Field) IsNull() Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, op: " IS NULL", len: len(col) + 8}
|
||||
|
||||
22
pgm_table.go
22
pgm_table.go
@@ -2,10 +2,11 @@ package pgm
|
||||
|
||||
// Table in database
|
||||
type Table struct {
|
||||
Name string
|
||||
PK []string
|
||||
FieldCount uint16
|
||||
debug bool
|
||||
Name string
|
||||
DerivedTable Query
|
||||
PK []string
|
||||
FieldCount uint16
|
||||
debug bool
|
||||
}
|
||||
|
||||
// Debug when set true will print generated query string in stdout
|
||||
@@ -14,6 +15,10 @@ func (t *Table) Debug() Clause {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Table) Field(f string) Field {
|
||||
return Field(t.Name + "." + f)
|
||||
}
|
||||
|
||||
// Insert table statement
|
||||
func (t *Table) Insert() InsertClause {
|
||||
qb := &insertQry{
|
||||
@@ -30,10 +35,17 @@ func (t *Table) Insert() InsertClause {
|
||||
func (t *Table) Select(field ...Field) SelectClause {
|
||||
qb := &selectQry{
|
||||
debug: t.debug,
|
||||
table: t.Name,
|
||||
fields: field,
|
||||
}
|
||||
|
||||
if t.DerivedTable != nil {
|
||||
tName, args := t.DerivedTable.Build(true)
|
||||
qb.table = "(" + tName + ") AS " + t.Name
|
||||
qb.args = args
|
||||
} else {
|
||||
qb.table = t.Name
|
||||
}
|
||||
|
||||
return qb
|
||||
}
|
||||
|
||||
|
||||
11
playground/db/derived.go
Normal file
11
playground/db/derived.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package db
|
||||
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||
t := pgm.Table{
|
||||
Name: tblName,
|
||||
DerivedTable: fromQry,
|
||||
}
|
||||
return t
|
||||
}
|
||||
@@ -60,6 +60,27 @@ func TestSelectWithHaving(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectDerived(t *testing.T) {
|
||||
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
|
||||
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
|
||||
" ORDER BY t.status_id, t.created_at DESC"
|
||||
|
||||
qry := db.User.
|
||||
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
|
||||
Where(user.StatusID.Eq(1))
|
||||
|
||||
tbl := db.DerivedTable("t", qry)
|
||||
got := tbl.
|
||||
Select(tbl.Field("*")).
|
||||
Where(tbl.Field("rn").Lte(5)).
|
||||
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
|
||||
String()
|
||||
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithJoin(t *testing.T) {
|
||||
got := db.User.Select(user.Email, user.FirstName).
|
||||
Join(db.UserSession, user.ID, usersession.UserID).
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -103,6 +104,7 @@ type (
|
||||
First
|
||||
All
|
||||
Stringer
|
||||
Bulder
|
||||
}
|
||||
|
||||
RowScanner interface {
|
||||
@@ -128,6 +130,10 @@ type (
|
||||
AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error
|
||||
}
|
||||
|
||||
Bulder interface {
|
||||
Build(needArgs bool) (qry string, args []any)
|
||||
}
|
||||
|
||||
selectQry struct {
|
||||
table string
|
||||
fields []Field
|
||||
@@ -291,6 +297,17 @@ func (q *selectQry) raw(prefixArgs []any) (string, []any) {
|
||||
}
|
||||
|
||||
func (q *selectQry) String() string {
|
||||
qry, _ := q.Build(false)
|
||||
if q.debug {
|
||||
fmt.Println("***")
|
||||
fmt.Println(qry)
|
||||
fmt.Printf("%+v\n", q.args)
|
||||
fmt.Println("***")
|
||||
}
|
||||
return qry
|
||||
}
|
||||
|
||||
func (q *selectQry) Build(needArgs bool) (qry string, args []any) {
|
||||
sb := getSB()
|
||||
defer putSB(sb)
|
||||
|
||||
@@ -356,14 +373,19 @@ func (q *selectQry) String() string {
|
||||
sb.WriteString(strconv.Itoa(q.offset))
|
||||
}
|
||||
|
||||
qry := sb.String()
|
||||
qry = sb.String()
|
||||
if q.debug {
|
||||
fmt.Println("***")
|
||||
fmt.Println(qry)
|
||||
fmt.Printf("%+v\n", q.args)
|
||||
fmt.Println("***")
|
||||
}
|
||||
return qry
|
||||
|
||||
if needArgs {
|
||||
args = slices.Clone(q.args)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (q *selectQry) averageLen() int {
|
||||
|
||||
Reference in New Issue
Block a user