diff --git a/Makefile b/Makefile index 311bd5e..36e6d4f 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,9 @@ run: go run ./cmd -o ./playground/db ./playground/schema.sql + bench-select: - go test ./example -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out + go test ./playground -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out test: go test ./playground diff --git a/pgm_field.go b/pgm_field.go index 3402320..2dba56b 100644 --- a/pgm_field.go +++ b/pgm_field.go @@ -18,6 +18,18 @@ func (f Field) Count() Field { return Field("COUNT(" + f.String() + ")") } +func ConcatWs(sep string, fields ...Field) Field { + return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")") +} + +func StringAgg(exp, sep string) Field { + return Field("string_agg(" + exp + ",'" + sep + "')") +} + +func StringAggCast(exp, sep string) Field { + return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')") +} + // StringEscape will wrap field with: // // COALESCE(field, ”) @@ -141,6 +153,10 @@ func (f Field) DateTrunc(level, as string) Field { return Field("DATE_TRUNC('" + level + "', " + f.String() + ") AS " + as) } +func (f Field) TsRank(query, as string) Field { + return Field("TS_RANK(" + f.String() + ", " + query + ") AS " + as) +} + // EqualFold will use LOWER(column_name) = LOWER(val) for comparision func (f Field) EqFold(val string) Conditioner { col := f.String() @@ -210,16 +226,9 @@ func (f Field) NotInSubQuery(qry WhereClause) Conditioner { return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery} } -func ConcatWs(sep string, fields ...Field) Field { - return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")") -} - -func StringAgg(exp, sep string) Field { - return Field("string_agg(" + exp + ",'" + sep + "')") -} - -func StringAggCast(exp, sep string) Field { - return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')") +func (f Field) TsQuery(as string) Conditioner { + col := f.String() + return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5} } func joinFileds(fields []Field) string { diff --git a/pgm_table.go b/pgm_table.go index 2ccf8d0..9d72d3b 100644 --- a/pgm_table.go +++ b/pgm_table.go @@ -1,7 +1,14 @@ package pgm +import ( + "strconv" +) + // Table in database type Table struct { + tsQuery *string + tsQueryAs *string + Name string DerivedTable Query PK []string @@ -31,6 +38,12 @@ func (t *Table) Insert() InsertClause { return qb } +func (t *Table) WithTsQuery(q, as string) *Table { + t.tsQuery = &q + t.tsQueryAs = &as + return t +} + // Select table statement func (t *Table) Select(field ...Field) SelectClause { qb := &selectQry{ @@ -46,6 +59,18 @@ func (t *Table) Select(field ...Field) SelectClause { qb.table = t.Name } + if t.tsQuery != nil { + var as string + if t.tsQueryAs != nil && *t.tsQueryAs != "" { + as = *t.tsQueryAs + } else { + // add default as field + as = "query" + } + qb.args = append(qb.args, as) + qb.table += ", TO_TSQUERY('english', $" + strconv.Itoa(len(qb.args)) + ") " + as + } + return qb } diff --git a/playground/db/schema.go b/playground/db/schema.go index 9446337..1458c80 100644 --- a/playground/db/schema.go +++ b/playground/db/schema.go @@ -5,7 +5,7 @@ package db import "code.patial.tech/go/pgm" var ( - User = pgm.Table{Name: "users", FieldCount: 11} + User = pgm.Table{Name: "users", FieldCount: 12} UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8} BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5} Post = pgm.Table{Name: "posts", FieldCount: 5} diff --git a/playground/db/user/users.go b/playground/db/user/users.go index cbd86a0..ef87ea1 100644 --- a/playground/db/user/users.go +++ b/playground/db/user/users.go @@ -25,6 +25,8 @@ const ( StatusID pgm.Field = "users.status_id" // MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying" MfaKind pgm.Field = "users.mfa_kind" + // SearchVector field has db type "tsvector" + SearchVector pgm.Field = "users.search_vector" // CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP" CreatedAt pgm.Field = "users.created_at" // UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP" diff --git a/playground/qry_insert_test.go b/playground/qry_insert_test.go index 9277c5b..6564ea5 100644 --- a/playground/qry_insert_test.go +++ b/playground/qry_insert_test.go @@ -17,7 +17,7 @@ func TestInsertQuery(t *testing.T) { Returning(user.ID). String() - expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id" + expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING users.id" if got != expected { t.Errorf("\nexpected: %q\ngot: %q", expected, got) } diff --git a/playground/qry_select_test.go b/playground/qry_select_test.go index 0481d9d..c30ba37 100644 --- a/playground/qry_select_test.go +++ b/playground/qry_select_test.go @@ -60,27 +60,6 @@ func TestSelectWithHaving(t *testing.T) { } } -func TestSelectDerived(t *testing.T) { - expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" + - " FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" + - " ORDER BY t.status_id, t.created_at DESC" - - qry := db.User. - Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")). - Where(user.StatusID.Eq(1)) - - tbl := db.DerivedTable("t", qry) - got := tbl. - Select(tbl.Field("*")). - Where(tbl.Field("rn").Lte(5)). - OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()). - String() - - if expected != got { - t.Errorf("\nexpected: %q\n\ngot: %q", expected, got) - } -} - func TestSelectWithJoin(t *testing.T) { got := db.User.Select(user.Email, user.FirstName). Join(db.UserSession, user.ID, usersession.UserID). @@ -106,6 +85,46 @@ func TestSelectWithJoin(t *testing.T) { } } +func TestSelectDerived(t *testing.T) { + expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" + + " FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" + + " ORDER BY t.status_id, t.created_at DESC" + + qry := db.User. + Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")). + Where(user.StatusID.Eq(1)) + + tbl := db.DerivedTable("t", qry) + got := tbl. + Select(tbl.Field("*")). + Where(tbl.Field("rn").Lte(5)). + OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()). + String() + + if expected != got { + t.Errorf("\nexpected: %q\n\ngot: %q", expected, got) + } +} + +func TestSelectTV(t *testing.T) { + expected := "SELECT users.first_name, users.last_name, users.email, TS_RANK(users.search_vector, query) AS rank" + + " FROM users, TO_TSQUERY('english', $1) query" + + " WHERE users.status_id = $2 AND users.search_vector @@ query" + + " ORDER BY rank DESC" + + qry := db.User. + WithTsQuery("anki", "query"). + Select(user.FirstName, user.LastName, user.Email, user.SearchVector.TsRank("query", "rank")). + Where(user.StatusID.Eq(1), user.SearchVector.TsQuery("query")). + OrderBy(pgm.Field("rank").Desc()) + + got := qry.String() + + if expected != got { + t.Errorf("\nexpected: %q\n\ngot: %q", expected, got) + } +} + // BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op func BenchmarkSelect(b *testing.B) { for b.Loop() { diff --git a/playground/schema.sql b/playground/schema.sql index 8e2f9fd..52421e0 100644 --- a/playground/schema.sql +++ b/playground/schema.sql @@ -61,6 +61,7 @@ CREATE TABLE public.users ( last_name character varying(50) NOT NULL, status_id smallint, mfa_kind character varying(50) DEFAULT 'None'::character varying, + search_vector tsvector, created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP ); diff --git a/qry_insert.go b/qry_insert.go index aa871b0..3490920 100644 --- a/qry_insert.go +++ b/qry_insert.go @@ -55,7 +55,7 @@ func (q *insertQry) SetMap(cols map[Field]any) InsertClause { } func (q *insertQry) Returning(field Field) First { - col := field.Name() + col := field.String() q.returing = &col return q }