Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a795c0e8d6 | |||
| bb6a45732f | |||
| 2551e07b3e |
3
Makefile
3
Makefile
@@ -2,8 +2,9 @@
|
|||||||
|
|
||||||
run:
|
run:
|
||||||
go run ./cmd -o ./playground/db ./playground/schema.sql
|
go run ./cmd -o ./playground/db ./playground/schema.sql
|
||||||
|
|
||||||
bench-select:
|
bench-select:
|
||||||
go test ./example -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
go test ./playground -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test ./playground
|
go test ./playground
|
||||||
|
|||||||
@@ -76,8 +76,18 @@ func generate(scheamPath, outDir string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(")")
|
sb.WriteString(")")
|
||||||
|
|
||||||
|
sb.WriteString(`
|
||||||
|
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||||
|
t := pgm.Table{
|
||||||
|
Name: tblName,
|
||||||
|
DerivedTable: fromQry,
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}`)
|
||||||
|
|
||||||
// Format code before saving
|
// Format code before saving
|
||||||
code, err := formatGoCode(sb.String())
|
code, err := formatGoCode(sb.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
39
pgm_field.go
39
pgm_field.go
@@ -18,6 +18,18 @@ func (f Field) Count() Field {
|
|||||||
return Field("COUNT(" + f.String() + ")")
|
return Field("COUNT(" + f.String() + ")")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ConcatWs(sep string, fields ...Field) Field {
|
||||||
|
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
func StringAgg(exp, sep string) Field {
|
||||||
|
return Field("string_agg(" + exp + ",'" + sep + "')")
|
||||||
|
}
|
||||||
|
|
||||||
|
func StringAggCast(exp, sep string) Field {
|
||||||
|
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
|
||||||
|
}
|
||||||
|
|
||||||
// StringEscape will wrap field with:
|
// StringEscape will wrap field with:
|
||||||
//
|
//
|
||||||
// COALESCE(field, ”)
|
// COALESCE(field, ”)
|
||||||
@@ -131,6 +143,20 @@ func (f Field) IsNotNull() Conditioner {
|
|||||||
return &Cond{Field: col, op: " IS NOT NULL", len: len(col) + 12}
|
return &Cond{Field: col, op: " IS NOT NULL", len: len(col) + 12}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DateTrunc will truncate date or timestamp to specified level of precision
|
||||||
|
//
|
||||||
|
// Level values:
|
||||||
|
// - microseconds, milliseconds, second, minute, hour
|
||||||
|
// - day, week (Monday start), month, quarter, year
|
||||||
|
// - decade, century, millennium
|
||||||
|
func (f Field) DateTrunc(level, as string) Field {
|
||||||
|
return Field("DATE_TRUNC('" + level + "', " + f.String() + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) TsRank(query, as string) Field {
|
||||||
|
return Field("TS_RANK(" + f.String() + ", " + query + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
// EqualFold will use LOWER(column_name) = LOWER(val) for comparision
|
// EqualFold will use LOWER(column_name) = LOWER(val) for comparision
|
||||||
func (f Field) EqFold(val string) Conditioner {
|
func (f Field) EqFold(val string) Conditioner {
|
||||||
col := f.String()
|
col := f.String()
|
||||||
@@ -200,16 +226,9 @@ func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
|
|||||||
return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery}
|
return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConcatWs(sep string, fields ...Field) Field {
|
func (f Field) TsQuery(as string) Conditioner {
|
||||||
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
|
col := f.String()
|
||||||
}
|
return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5}
|
||||||
|
|
||||||
func StringAgg(exp, sep string) Field {
|
|
||||||
return Field("string_agg(" + exp + ",'" + sep + "')")
|
|
||||||
}
|
|
||||||
|
|
||||||
func StringAggCast(exp, sep string) Field {
|
|
||||||
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func joinFileds(fields []Field) string {
|
func joinFileds(fields []Field) string {
|
||||||
|
|||||||
25
pgm_table.go
25
pgm_table.go
@@ -1,7 +1,14 @@
|
|||||||
package pgm
|
package pgm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
// Table in database
|
// Table in database
|
||||||
type Table struct {
|
type Table struct {
|
||||||
|
tsQuery *string
|
||||||
|
tsQueryAs *string
|
||||||
|
|
||||||
Name string
|
Name string
|
||||||
DerivedTable Query
|
DerivedTable Query
|
||||||
PK []string
|
PK []string
|
||||||
@@ -31,6 +38,12 @@ func (t *Table) Insert() InsertClause {
|
|||||||
return qb
|
return qb
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Table) WithTsQuery(q, as string) *Table {
|
||||||
|
t.tsQuery = &q
|
||||||
|
t.tsQueryAs = &as
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
// Select table statement
|
// Select table statement
|
||||||
func (t *Table) Select(field ...Field) SelectClause {
|
func (t *Table) Select(field ...Field) SelectClause {
|
||||||
qb := &selectQry{
|
qb := &selectQry{
|
||||||
@@ -46,6 +59,18 @@ func (t *Table) Select(field ...Field) SelectClause {
|
|||||||
qb.table = t.Name
|
qb.table = t.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if t.tsQuery != nil {
|
||||||
|
var as string
|
||||||
|
if t.tsQueryAs != nil && *t.tsQueryAs != "" {
|
||||||
|
as = *t.tsQueryAs
|
||||||
|
} else {
|
||||||
|
// add default as field
|
||||||
|
as = "query"
|
||||||
|
}
|
||||||
|
qb.args = append(qb.args, as)
|
||||||
|
qb.table += ", TO_TSQUERY('english', $" + strconv.Itoa(len(qb.args)) + ") " + as
|
||||||
|
}
|
||||||
|
|
||||||
return qb
|
return qb
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
|
||||||
|
|
||||||
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
|
||||||
t := pgm.Table{
|
|
||||||
Name: tblName,
|
|
||||||
DerivedTable: fromQry,
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
@@ -5,10 +5,18 @@ package db
|
|||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
User = pgm.Table{Name: "users", FieldCount: 11}
|
User = pgm.Table{Name: "users", FieldCount: 12}
|
||||||
UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8}
|
UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8}
|
||||||
BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5}
|
BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5}
|
||||||
Post = pgm.Table{Name: "posts", FieldCount: 5}
|
Post = pgm.Table{Name: "posts", FieldCount: 5}
|
||||||
Comment = pgm.Table{Name: "comments", FieldCount: 5}
|
Comment = pgm.Table{Name: "comments", FieldCount: 5}
|
||||||
Employee = pgm.Table{Name: "employees", FieldCount: 5}
|
Employee = pgm.Table{Name: "employees", FieldCount: 5}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||||
|
t := pgm.Table{
|
||||||
|
Name: tblName,
|
||||||
|
DerivedTable: fromQry,
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ const (
|
|||||||
StatusID pgm.Field = "users.status_id"
|
StatusID pgm.Field = "users.status_id"
|
||||||
// MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying"
|
// MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying"
|
||||||
MfaKind pgm.Field = "users.mfa_kind"
|
MfaKind pgm.Field = "users.mfa_kind"
|
||||||
|
// SearchVector field has db type "tsvector"
|
||||||
|
SearchVector pgm.Field = "users.search_vector"
|
||||||
// CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
// CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||||
CreatedAt pgm.Field = "users.created_at"
|
CreatedAt pgm.Field = "users.created_at"
|
||||||
// UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
// UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ func TestInsertQuery(t *testing.T) {
|
|||||||
Returning(user.ID).
|
Returning(user.ID).
|
||||||
String()
|
String()
|
||||||
|
|
||||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id"
|
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING users.id"
|
||||||
if got != expected {
|
if got != expected {
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,27 +60,6 @@ func TestSelectWithHaving(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectDerived(t *testing.T) {
|
|
||||||
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
|
|
||||||
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
|
|
||||||
" ORDER BY t.status_id, t.created_at DESC"
|
|
||||||
|
|
||||||
qry := db.User.
|
|
||||||
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
|
|
||||||
Where(user.StatusID.Eq(1))
|
|
||||||
|
|
||||||
tbl := db.DerivedTable("t", qry)
|
|
||||||
got := tbl.
|
|
||||||
Select(tbl.Field("*")).
|
|
||||||
Where(tbl.Field("rn").Lte(5)).
|
|
||||||
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
|
|
||||||
String()
|
|
||||||
|
|
||||||
if expected != got {
|
|
||||||
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSelectWithJoin(t *testing.T) {
|
func TestSelectWithJoin(t *testing.T) {
|
||||||
got := db.User.Select(user.Email, user.FirstName).
|
got := db.User.Select(user.Email, user.FirstName).
|
||||||
Join(db.UserSession, user.ID, usersession.UserID).
|
Join(db.UserSession, user.ID, usersession.UserID).
|
||||||
@@ -106,6 +85,46 @@ func TestSelectWithJoin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSelectDerived(t *testing.T) {
|
||||||
|
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
|
||||||
|
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
|
||||||
|
" ORDER BY t.status_id, t.created_at DESC"
|
||||||
|
|
||||||
|
qry := db.User.
|
||||||
|
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
|
||||||
|
Where(user.StatusID.Eq(1))
|
||||||
|
|
||||||
|
tbl := db.DerivedTable("t", qry)
|
||||||
|
got := tbl.
|
||||||
|
Select(tbl.Field("*")).
|
||||||
|
Where(tbl.Field("rn").Lte(5)).
|
||||||
|
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
|
||||||
|
String()
|
||||||
|
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectTV(t *testing.T) {
|
||||||
|
expected := "SELECT users.first_name, users.last_name, users.email, TS_RANK(users.search_vector, query) AS rank" +
|
||||||
|
" FROM users, TO_TSQUERY('english', $1) query" +
|
||||||
|
" WHERE users.status_id = $2 AND users.search_vector @@ query" +
|
||||||
|
" ORDER BY rank DESC"
|
||||||
|
|
||||||
|
qry := db.User.
|
||||||
|
WithTsQuery("anki", "query").
|
||||||
|
Select(user.FirstName, user.LastName, user.Email, user.SearchVector.TsRank("query", "rank")).
|
||||||
|
Where(user.StatusID.Eq(1), user.SearchVector.TsQuery("query")).
|
||||||
|
OrderBy(pgm.Field("rank").Desc())
|
||||||
|
|
||||||
|
got := qry.String()
|
||||||
|
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
|
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
|
||||||
func BenchmarkSelect(b *testing.B) {
|
func BenchmarkSelect(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ CREATE TABLE public.users (
|
|||||||
last_name character varying(50) NOT NULL,
|
last_name character varying(50) NOT NULL,
|
||||||
status_id smallint,
|
status_id smallint,
|
||||||
mfa_kind character varying(50) DEFAULT 'None'::character varying,
|
mfa_kind character varying(50) DEFAULT 'None'::character varying,
|
||||||
|
search_vector tsvector,
|
||||||
created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP
|
updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func (q *insertQry) SetMap(cols map[Field]any) InsertClause {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) Returning(field Field) First {
|
func (q *insertQry) Returning(field Field) First {
|
||||||
col := field.Name()
|
col := field.String()
|
||||||
q.returing = &col
|
q.returing = &col
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user