// Patial Tech. // Author, Ankit Patial package pgm import ( "context" "errors" "fmt" "strconv" "strings" "github.com/jackc/pgx/v5" ) type ( selectQry struct { table string fields []Field args []any join []string where []Conditioner groupBy []Field having []Conditioner orderBy []Field limit int offset int debug bool } CondAction uint8 Cond struct { Val any op string Field string len int action CondAction } CondGroup struct { op string cond []Conditioner } ) // Contdition actions const ( CondActionNothing CondAction = iota CondActionNeedToClose CondActionSubQuery ) // Select clause func (t Table) Select(field ...Field) SelectClause { qb := &selectQry{ table: t.Name, debug: t.debug, fields: field, } return qb } func (q *selectQry) Join(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { return q.buildJoin(t, "JOIN", t1Field, t2Field, cond...) } func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { return q.buildJoin(t, "LEFT JOIN", t1Field, t2Field, cond...) } func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { return q.buildJoin(t, "RIGHT JOIN", t1Field, t2Field, cond...) } func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { return q.buildJoin(t, "FULL JOIN", t1Field, t2Field, cond...) } func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, cond ...Conditioner) SelectClause { str := joinKW + " " + t.Name + " ON " + t1Field.String() + " = " + t2Field.String() if len(cond) == 0 { // Join with no condition q.join = append(q.join, str) return q } // Join has condition(s) sb := getSB() defer putSB(sb) sb.Grow(len(str) * 2) sb.WriteString(str + " AND ") var argIdx int for i, c := range cond { argIdx = len(q.args) if i > 0 { sb.WriteString(" AND ") } sb.WriteString(c.Condition(&q.args, argIdx)) } q.join = append(q.join, sb.String()) return q } func (q *selectQry) CrossJoin(t Table) SelectClause { q.join = append(q.join, "CROSS JOIN "+t.Name) return q } func (q *selectQry) Where(cond ...Conditioner) AfterWhere { q.where = append(q.where, cond...) return q } func (q *selectQry) OrderBy(fields ...Field) AfterOrderBy { q.orderBy = fields return q } func (q *selectQry) GroupBy(fields ...Field) AfterGroupBy { q.groupBy = fields return q } func (q *selectQry) Having(cond ...Conditioner) AfterHaving { q.having = append(q.having, cond...) return q } func (q *selectQry) Limit(v int) AfterLimit { q.limit = v return q } func (q *selectQry) Offset(v int) AfterOffset { q.offset = v return q } func (q *selectQry) First(ctx context.Context, dest ...any) error { return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...) } func (q *selectQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error { return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...) } func (q *selectQry) All(ctx context.Context, row RowsCb) error { rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...) if errors.Is(err, pgx.ErrNoRows) { return ErrNoRows } defer rows.Close() for rows.Next() { if err := row(rows); err != nil { return err } } return nil } func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error { rows, err := tx.Query(ctx, q.String(), q.args...) if errors.Is(err, pgx.ErrNoRows) { return ErrNoRows } defer rows.Close() for rows.Next() { if err := row(rows); err != nil { return err } } return nil } func (q *selectQry) raw(prefixArgs []any) (string, []any) { q.args = append(prefixArgs, q.args...) return q.String(), q.args } func (q *selectQry) String() string { sb := getSB() defer putSB(sb) sb.Grow(q.averageLen()) // SELECT sb.WriteString("SELECT ") sb.WriteString(joinFileds(q.fields)) sb.WriteString(" FROM " + q.table) // JOIN if len(q.join) > 0 { sb.WriteString(" " + strings.Join(q.join, " ")) } // WHERE if len(q.where) > 0 { sb.WriteString(" WHERE ") var argIdx int for i, c := range q.where { argIdx = len(q.args) if i > 0 { sb.WriteString(" AND ") } sb.WriteString(c.Condition(&q.args, argIdx)) } } // GROUP BY if len(q.groupBy) > 0 { sb.WriteString(" GROUP BY ") sb.WriteString(joinFileds(q.groupBy)) } // HAVING if len(q.having) > 0 { sb.WriteString(" HAVING ") var argIdx int for i, c := range q.having { argIdx = len(q.args) if i > 0 { sb.WriteString(" AND ") } sb.WriteString(c.Condition(&q.args, argIdx)) } } // ORDER BY if len(q.orderBy) > 0 { sb.WriteString(" ORDER BY ") sb.WriteString(joinFileds(q.orderBy)) } // LIMIT if q.limit > 0 { sb.WriteString(" LIMIT ") sb.WriteString(strconv.Itoa(q.limit)) } // OFFSET if q.offset > 0 { sb.WriteString(" OFFSET ") sb.WriteString(strconv.Itoa(q.offset)) } qry := sb.String() if q.debug { fmt.Println("***") fmt.Println(qry) fmt.Printf("%+v\n", q.args) fmt.Println("***") } return qry } func (q *selectQry) averageLen() int { n := 12 + len(q.table) // SELECT FROM for _, c := range q.fields { n += (len(c) + 2) * len(q.table) // columns with tablename.columnname } // JOIN if len(q.join) > 0 { for _, c := range q.join { n += len(c) + 2 // with whitespace } } // WHERE if len(q.where) > 0 { n += 7 + len(q.args)*5 // WHERE with 2 sapces and each args roughly with min of 5 char } // GROUP BY if len(q.groupBy) > 0 { n += 10 // GROUP BY for _, c := range q.groupBy { n += len(c) + 2 // one command and a whitespace } } // ORDER BY if len(q.orderBy) > 0 { n += 10 // ORDER BY for _, c := range q.orderBy { n += len(c) + 2 // one command and a whitespace } } // LIMIT if q.limit > 0 { n += 10 // LIMIT } // OFFSET if q.offset > 0 { n += 10 // OFFSET } return n }