// Patial Tech. // Author, Ankit Patial package pgm import ( "context" "errors" "fmt" "strconv" "strings" "github.com/jackc/pgx/v5" ) type ( SelectClause interface { // Join and Inner Join are same Join(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause LeftJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause RightJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause FullJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause CrossJoin(m Table) SelectClause WhereClause OrderByClause GroupByClause LimitClause OffsetClause Query raw(prefixArgs []any) (string, []any) } WhereClause interface { Where(cond ...Conditioner) AfterWhere } AfterWhere interface { WhereClause GroupByClause OrderByClause LimitClause OffsetClause Query } GroupByClause interface { GroupBy(fields ...Field) AfterGroupBy } AfterGroupBy interface { HavinClause OrderByClause LimitClause OffsetClause Query } HavinClause interface { Having(cond ...Conditioner) AfterHaving } AfterHaving interface { OrderByClause LimitClause OffsetClause Query } OrderByClause interface { OrderBy(fields ...Field) AfterOrderBy } AfterOrderBy interface { LimitClause OffsetClause Query } LimitClause interface { Limit(v int) AfterLimit } AfterLimit interface { OffsetClause Query } OffsetClause interface { Offset(v int) AfterOffset } AfterOffset interface { LimitClause Query } Do interface { DoNothing() Execute DoUpdate(fields ...Field) Execute } Query interface { First All Stringer } RowScanner interface { Scan(dest ...any) error } RowsCb func(row RowScanner) error First interface { First(ctx context.Context, dest ...any) error FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error Stringer } All interface { // Query rows // // don't forget to close() rows All(ctx context.Context, rows RowsCb) error // Query rows // // don't forget to close() rows AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error } selectQry struct { table string fields []Field args []any join []string where []Conditioner groupBy []Field having []Conditioner orderBy []Field limit int offset int debug bool } CondAction uint8 Cond struct { Val any op string Field string len int action CondAction } CondGroup struct { op string cond []Conditioner } ) // Contdition actions const ( CondActionNothing CondAction = iota CondActionNeedToClose CondActionSubQuery ) func (q *selectQry) Join(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { return q.buildJoin(t, "JOIN", t1Field, t2Field, cond...) } func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { return q.buildJoin(t, "LEFT JOIN", t1Field, t2Field, cond...) } func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { return q.buildJoin(t, "RIGHT JOIN", t1Field, t2Field, cond...) } func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { return q.buildJoin(t, "FULL JOIN", t1Field, t2Field, cond...) } func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, cond ...Conditioner) SelectClause { str := joinKW + " " + t.Name + " ON " + t1Field.String() + " = " + t2Field.String() if len(cond) == 0 { // Join with no condition q.join = append(q.join, str) return q } // Join has condition(s) sb := getSB() defer putSB(sb) sb.Grow(len(str) * 2) sb.WriteString(str + " AND ") var argIdx int for i, c := range cond { argIdx = len(q.args) if i > 0 { sb.WriteString(" AND ") } sb.WriteString(c.Condition(&q.args, argIdx)) } q.join = append(q.join, sb.String()) return q } func (q *selectQry) CrossJoin(t Table) SelectClause { q.join = append(q.join, "CROSS JOIN "+t.Name) return q } func (q *selectQry) Where(cond ...Conditioner) AfterWhere { q.where = append(q.where, cond...) return q } func (q *selectQry) OrderBy(fields ...Field) AfterOrderBy { q.orderBy = fields return q } func (q *selectQry) GroupBy(fields ...Field) AfterGroupBy { q.groupBy = fields return q } func (q *selectQry) Having(cond ...Conditioner) AfterHaving { q.having = append(q.having, cond...) return q } func (q *selectQry) Limit(v int) AfterLimit { q.limit = v return q } func (q *selectQry) Offset(v int) AfterOffset { q.offset = v return q } func (q *selectQry) First(ctx context.Context, dest ...any) error { return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...) } func (q *selectQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error { return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...) } func (q *selectQry) All(ctx context.Context, row RowsCb) error { rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...) if errors.Is(err, pgx.ErrNoRows) { return ErrNoRows } defer rows.Close() for rows.Next() { if err := row(rows); err != nil { return err } } return nil } func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error { rows, err := tx.Query(ctx, q.String(), q.args...) if errors.Is(err, pgx.ErrNoRows) { return ErrNoRows } defer rows.Close() for rows.Next() { if err := row(rows); err != nil { return err } } return nil } func (q *selectQry) raw(prefixArgs []any) (string, []any) { q.args = append(prefixArgs, q.args...) return q.String(), q.args } func (q *selectQry) String() string { sb := getSB() defer putSB(sb) sb.Grow(q.averageLen()) // SELECT sb.WriteString("SELECT ") sb.WriteString(joinFileds(q.fields)) sb.WriteString(" FROM " + q.table) // JOIN if len(q.join) > 0 { sb.WriteString(" " + strings.Join(q.join, " ")) } // WHERE if len(q.where) > 0 { sb.WriteString(" WHERE ") var argIdx int for i, c := range q.where { argIdx = len(q.args) if i > 0 { sb.WriteString(" AND ") } sb.WriteString(c.Condition(&q.args, argIdx)) } } // GROUP BY if len(q.groupBy) > 0 { sb.WriteString(" GROUP BY ") sb.WriteString(joinFileds(q.groupBy)) } // HAVING if len(q.having) > 0 { sb.WriteString(" HAVING ") var argIdx int for i, c := range q.having { argIdx = len(q.args) if i > 0 { sb.WriteString(" AND ") } sb.WriteString(c.Condition(&q.args, argIdx)) } } // ORDER BY if len(q.orderBy) > 0 { sb.WriteString(" ORDER BY ") sb.WriteString(joinFileds(q.orderBy)) } // LIMIT if q.limit > 0 { sb.WriteString(" LIMIT ") sb.WriteString(strconv.Itoa(q.limit)) } // OFFSET if q.offset > 0 { sb.WriteString(" OFFSET ") sb.WriteString(strconv.Itoa(q.offset)) } qry := sb.String() if q.debug { fmt.Println("***") fmt.Println(qry) fmt.Printf("%+v\n", q.args) fmt.Println("***") } return qry } func (q *selectQry) averageLen() int { n := 12 + len(q.table) // SELECT FROM for _, c := range q.fields { n += (len(c) + 2) * len(q.table) // columns with tablename.columnname } // JOIN if len(q.join) > 0 { for _, c := range q.join { n += len(c) + 2 // with whitespace } } // WHERE if len(q.where) > 0 { n += 7 + len(q.args)*5 // WHERE with 2 sapces and each args roughly with min of 5 char } // GROUP BY if len(q.groupBy) > 0 { n += 10 // GROUP BY for _, c := range q.groupBy { n += len(c) + 2 // one command and a whitespace } } // ORDER BY if len(q.orderBy) > 0 { n += 10 // ORDER BY for _, c := range q.orderBy { n += len(c) + 2 // one command and a whitespace } } // LIMIT if q.limit > 0 { n += 10 // LIMIT } // OFFSET if q.offset > 0 { n += 10 // OFFSET } return n }