revamp, returing will include the table name.cloumn_name. feat, ts_rank, ts_search helpers

This commit is contained in:
2025-11-02 22:04:02 +05:30
parent bb6a45732f
commit 335046658e
9 changed files with 92 additions and 35 deletions

View File

@@ -2,8 +2,9 @@
run:
go run ./cmd -o ./playground/db ./playground/schema.sql
bench-select:
go test ./example -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
go test ./playground -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
test:
go test ./playground

View File

@@ -18,6 +18,18 @@ func (f Field) Count() Field {
return Field("COUNT(" + f.String() + ")")
}
func ConcatWs(sep string, fields ...Field) Field {
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
}
func StringAgg(exp, sep string) Field {
return Field("string_agg(" + exp + ",'" + sep + "')")
}
func StringAggCast(exp, sep string) Field {
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
}
// StringEscape will wrap field with:
//
// COALESCE(field, ”)
@@ -141,6 +153,10 @@ func (f Field) DateTrunc(level, as string) Field {
return Field("DATE_TRUNC('" + level + "', " + f.String() + ") AS " + as)
}
func (f Field) TsRank(query, as string) Field {
return Field("TS_RANK(" + f.String() + ", " + query + ") AS " + as)
}
// EqualFold will use LOWER(column_name) = LOWER(val) for comparision
func (f Field) EqFold(val string) Conditioner {
col := f.String()
@@ -210,16 +226,9 @@ func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery}
}
func ConcatWs(sep string, fields ...Field) Field {
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
}
func StringAgg(exp, sep string) Field {
return Field("string_agg(" + exp + ",'" + sep + "')")
}
func StringAggCast(exp, sep string) Field {
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
func (f Field) TsQuery(as string) Conditioner {
col := f.String()
return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5}
}
func joinFileds(fields []Field) string {

View File

@@ -1,7 +1,14 @@
package pgm
import (
"strconv"
)
// Table in database
type Table struct {
tsQuery *string
tsQueryAs *string
Name string
DerivedTable Query
PK []string
@@ -31,6 +38,12 @@ func (t *Table) Insert() InsertClause {
return qb
}
func (t *Table) WithTsQuery(q, as string) *Table {
t.tsQuery = &q
t.tsQueryAs = &as
return t
}
// Select table statement
func (t *Table) Select(field ...Field) SelectClause {
qb := &selectQry{
@@ -46,6 +59,18 @@ func (t *Table) Select(field ...Field) SelectClause {
qb.table = t.Name
}
if t.tsQuery != nil {
var as string
if t.tsQueryAs != nil && *t.tsQueryAs != "" {
as = *t.tsQueryAs
} else {
// add default as field
as = "query"
}
qb.args = append(qb.args, as)
qb.table += ", TO_TSQUERY('english', $" + strconv.Itoa(len(qb.args)) + ") " + as
}
return qb
}

View File

@@ -5,7 +5,7 @@ package db
import "code.patial.tech/go/pgm"
var (
User = pgm.Table{Name: "users", FieldCount: 11}
User = pgm.Table{Name: "users", FieldCount: 12}
UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8}
BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5}
Post = pgm.Table{Name: "posts", FieldCount: 5}

View File

@@ -25,6 +25,8 @@ const (
StatusID pgm.Field = "users.status_id"
// MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying"
MfaKind pgm.Field = "users.mfa_kind"
// SearchVector field has db type "tsvector"
SearchVector pgm.Field = "users.search_vector"
// CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
CreatedAt pgm.Field = "users.created_at"
// UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"

View File

@@ -17,7 +17,7 @@ func TestInsertQuery(t *testing.T) {
Returning(user.ID).
String()
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id"
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING users.id"
if got != expected {
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
}

View File

@@ -60,27 +60,6 @@ func TestSelectWithHaving(t *testing.T) {
}
}
func TestSelectDerived(t *testing.T) {
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
" ORDER BY t.status_id, t.created_at DESC"
qry := db.User.
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
Where(user.StatusID.Eq(1))
tbl := db.DerivedTable("t", qry)
got := tbl.
Select(tbl.Field("*")).
Where(tbl.Field("rn").Lte(5)).
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
String()
if expected != got {
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
}
}
func TestSelectWithJoin(t *testing.T) {
got := db.User.Select(user.Email, user.FirstName).
Join(db.UserSession, user.ID, usersession.UserID).
@@ -106,6 +85,46 @@ func TestSelectWithJoin(t *testing.T) {
}
}
func TestSelectDerived(t *testing.T) {
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
" ORDER BY t.status_id, t.created_at DESC"
qry := db.User.
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
Where(user.StatusID.Eq(1))
tbl := db.DerivedTable("t", qry)
got := tbl.
Select(tbl.Field("*")).
Where(tbl.Field("rn").Lte(5)).
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
String()
if expected != got {
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
}
}
func TestSelectTV(t *testing.T) {
expected := "SELECT users.first_name, users.last_name, users.email, TS_RANK(users.search_vector, query) AS rank" +
" FROM users, TO_TSQUERY('english', $1) query" +
" WHERE users.status_id = $2 AND users.search_vector @@ query" +
" ORDER BY rank DESC"
qry := db.User.
WithTsQuery("anki", "query").
Select(user.FirstName, user.LastName, user.Email, user.SearchVector.TsRank("query", "rank")).
Where(user.StatusID.Eq(1), user.SearchVector.TsQuery("query")).
OrderBy(pgm.Field("rank").Desc())
got := qry.String()
if expected != got {
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
}
}
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
func BenchmarkSelect(b *testing.B) {
for b.Loop() {

View File

@@ -61,6 +61,7 @@ CREATE TABLE public.users (
last_name character varying(50) NOT NULL,
status_id smallint,
mfa_kind character varying(50) DEFAULT 'None'::character varying,
search_vector tsvector,
created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP
);

View File

@@ -55,7 +55,7 @@ func (q *insertQry) SetMap(cols map[Field]any) InsertClause {
}
func (q *insertQry) Returning(field Field) First {
col := field.Name()
col := field.String()
q.returing = &col
return q
}