forked from go/pgm
- Updated RowNumber, RowNumberDesc, RowNumberPartionBy, and RowNumberDescPartionBy to accept variadic extraOrderBy parameters - Fixed bug in RowNumberDesc that was incorrectly using ASC instead of DESC - Enhanced rowNumber internal function to build ORDER BY clause with primary field and additional fields - Backwards compatible - existing code continues to work without extra fields - Added CLAUDE.md documentation for future Claude Code instances
341 lines
9.3 KiB
Go
341 lines
9.3 KiB
Go
package pgm
|
|
|
|
import (
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
)
|
|
|
|
// Field is a column reference or SQL expression fragment, for example
// "users.id" or "COUNT(users.id)". Its methods wrap it in SQL functions,
// append ordering keywords or build conditions; the text itself is spliced
// into the generated SQL verbatim.
type Field string
|
|
|
|
var (
	// sqlIdentifierRegex matches a bare SQL identifier: an ASCII letter or
	// underscore followed by any number of ASCII letters, digits or
	// underscores. Everything else (dots, quotes, spaces, the empty string)
	// is rejected, which keeps aliases safe to splice into generated SQL.
	sqlIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

	// validDateTruncLevels is the allow-list of DATE_TRUNC precision names.
	// Keys are lower case and every present key maps to true.
	validDateTruncLevels = func() map[string]bool {
		levels := []string{
			"microseconds", "milliseconds", "second", "minute", "hour",
			"day", "week", "month", "quarter", "year",
			"decade", "century", "millennium",
		}
		allowed := make(map[string]bool, len(levels))
		for _, level := range levels {
			allowed[level] = true
		}
		return allowed
	}()
)
|
|
|
|
// validateSQLIdentifier checks if a string is a valid SQL identifier
|
|
func validateSQLIdentifier(s string) error {
|
|
if !sqlIdentifierRegex.MatchString(s) {
|
|
return fmt.Errorf("invalid SQL identifier: %q (must be alphanumeric and underscore only)", s)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// escapeSQLString doubles every single quote in s so that the result can be
// placed between single quotes as a SQL string literal ("it's" -> "it''s").
// Backslashes are left untouched.
// NOTE(review): this assumes standard_conforming_strings is on — confirm.
func escapeSQLString(s string) string {
	return strings.Replace(s, "'", "''", -1)
}
|
|
|
|
func (f Field) Name() string {
|
|
s := string(f)
|
|
idx := strings.LastIndexByte(s, '.')
|
|
if idx == -1 {
|
|
return s // Return as-is if no dot
|
|
}
|
|
return s[idx+1:] // Return part after last dot
|
|
}
|
|
|
|
func (f Field) String() string {
|
|
return string(f)
|
|
}
|
|
|
|
// Count function wrapped field
|
|
func (f Field) Count() Field {
|
|
return Field("COUNT(" + f.String() + ")")
|
|
}
|
|
|
|
// ConcatWs creates a CONCAT_WS SQL function.
|
|
// SECURITY: The sep parameter should only be a constant string, not user input.
|
|
// Single quotes in sep will be escaped automatically.
|
|
func ConcatWs(sep string, fields ...Field) Field {
|
|
escapedSep := escapeSQLString(sep)
|
|
return Field("concat_ws('" + escapedSep + "'," + joinFileds(fields) + ")")
|
|
}
|
|
|
|
// StringAgg creates a STRING_AGG SQL function.
|
|
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
|
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
|
func StringAgg(field Field, sep string) Field {
|
|
escapedSep := escapeSQLString(sep)
|
|
return Field("string_agg(" + field.String() + ",'" + escapedSep + "')")
|
|
}
|
|
|
|
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
|
|
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
|
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
|
func StringAggCast(field Field, sep string) Field {
|
|
escapedSep := escapeSQLString(sep)
|
|
return Field("string_agg(cast(" + field.String() + " as varchar),'" + escapedSep + "')")
|
|
}
|
|
|
|
// StringEscape will wrap field with:
|
|
//
|
|
// COALESCE(field, ”)
|
|
func (f Field) StringEscape() Field {
|
|
return Field("COALESCE(" + f.String() + ", '')")
|
|
}
|
|
|
|
// NumberEscape will wrap field with:
|
|
//
|
|
// COALESCE(field, 0)
|
|
func (f Field) NumberEscape() Field {
|
|
return Field("COALESCE(" + f.String() + ", 0)")
|
|
}
|
|
|
|
// BooleanEscape will wrap field with:
|
|
//
|
|
// COALESCE(field, FALSE)
|
|
func (f Field) BooleanEscape() Field {
|
|
return Field("COALESCE(" + f.String() + ", FALSE)")
|
|
}
|
|
|
|
// Avg function wrapped field
|
|
func (f Field) Avg() Field {
|
|
return Field("AVG(" + f.String() + ")")
|
|
}
|
|
|
|
// Sum function wrapped field
|
|
func (f Field) Sum() Field {
|
|
return Field("SUM(" + f.String() + ")")
|
|
}
|
|
|
|
// Max function wrapped field
|
|
func (f Field) Max() Field {
|
|
return Field("MAX(" + f.String() + ")")
|
|
}
|
|
|
|
// Min function wrapped field
|
|
func (f Field) Min() Field {
|
|
return Field("Min(" + f.String() + ")")
|
|
}
|
|
|
|
// Lower function wrapped field
|
|
func (f Field) Lower() Field {
|
|
return Field("LOWER(" + f.String() + ")")
|
|
}
|
|
|
|
// Upper function wrapped field
|
|
func (f Field) Upper() Field {
|
|
return Field("UPPER(" + f.String() + ")")
|
|
}
|
|
|
|
// Trim function wrapped field
|
|
func (f Field) Trim() Field {
|
|
return Field("TRIM(" + f.String() + ")")
|
|
}
|
|
|
|
// Asc suffixed field, supposed to be used with order by
|
|
func (f Field) Asc() Field {
|
|
return Field(f.String() + " ASC")
|
|
}
|
|
|
|
// Desc suffixed field, supposed to be used with order by
|
|
func (f Field) Desc() Field {
|
|
return Field(f.String() + " DESC")
|
|
}
|
|
|
|
func (f Field) RowNumber(as string, extraOrderBy ...Field) Field {
|
|
return rowNumber(&f, nil, true, as, extraOrderBy...)
|
|
}
|
|
|
|
func (f Field) RowNumberDesc(as string, extraOrderBy ...Field) Field {
|
|
return rowNumber(&f, nil, false, as, extraOrderBy...)
|
|
}
|
|
|
|
// RowNumberPartionBy in ascending order
|
|
func (f Field) RowNumberPartionBy(partition Field, as string, extraOrderBy ...Field) Field {
|
|
return rowNumber(&f, &partition, true, as, extraOrderBy...)
|
|
}
|
|
|
|
func (f Field) RowNumberDescPartionBy(partition Field, as string, extraOrderBy ...Field) Field {
|
|
return rowNumber(&f, &partition, false, as, extraOrderBy...)
|
|
}
|
|
|
|
func rowNumber(f, partition *Field, isAsc bool, as string, extraOrderBy ...Field) Field {
|
|
// Validate as parameter is a valid SQL identifier
|
|
if as != "" {
|
|
if err := validateSQLIdentifier(as); err != nil {
|
|
panic(fmt.Sprintf("invalid AS alias in rowNumber: %v", err))
|
|
}
|
|
}
|
|
var orderBy string
|
|
if isAsc {
|
|
orderBy = " ASC"
|
|
} else {
|
|
orderBy = " DESC"
|
|
}
|
|
|
|
if as == "" {
|
|
as = "row_number"
|
|
}
|
|
|
|
col := f.String()
|
|
|
|
// Build ORDER BY clause with primary field and extra fields
|
|
sb := getSB()
|
|
defer putSB(sb)
|
|
|
|
sb.WriteString(col)
|
|
sb.WriteString(orderBy)
|
|
|
|
// Add extra ORDER BY fields
|
|
for _, extra := range extraOrderBy {
|
|
sb.WriteString(", ")
|
|
sb.WriteString(extra.String())
|
|
}
|
|
|
|
orderByClause := sb.String()
|
|
|
|
if partition != nil {
|
|
return Field("ROW_NUMBER() OVER (PARTITION BY " + partition.String() + " ORDER BY " + orderByClause + ") AS " + as)
|
|
}
|
|
|
|
return Field("ROW_NUMBER() OVER (ORDER BY " + orderByClause + ") AS " + as)
|
|
}
|
|
|
|
func (f Field) IsNull() Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, op: " IS NULL", len: len(col) + 8}
|
|
}
|
|
|
|
func (f Field) IsNotNull() Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, op: " IS NOT NULL", len: len(col) + 12}
|
|
}
|
|
|
|
// DateTrunc will truncate date or timestamp to specified level of precision
|
|
//
|
|
// Level values:
|
|
// - microseconds, milliseconds, second, minute, hour
|
|
// - day, week (Monday start), month, quarter, year
|
|
// - decade, century, millennium
|
|
func (f Field) DateTrunc(level, as string) Field {
|
|
// Validate level parameter against allowed values
|
|
if !validDateTruncLevels[strings.ToLower(level)] {
|
|
panic(fmt.Sprintf("invalid DATE_TRUNC level: %q (allowed: microseconds, milliseconds, second, minute, hour, day, week, month, quarter, year, decade, century, millennium)", level))
|
|
}
|
|
|
|
// Validate as parameter is a valid SQL identifier
|
|
if err := validateSQLIdentifier(as); err != nil {
|
|
panic(fmt.Sprintf("invalid AS alias in DateTrunc: %v", err))
|
|
}
|
|
|
|
return Field("DATE_TRUNC('" + strings.ToLower(level) + "', " + f.String() + ") AS " + as)
|
|
}
|
|
|
|
func (f Field) TsRank(fieldName, as string) Field {
|
|
// Validate as parameter is a valid SQL identifier
|
|
if err := validateSQLIdentifier(as); err != nil {
|
|
panic(fmt.Sprintf("invalid AS alias in TsRank: %v", err))
|
|
}
|
|
|
|
return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as)
|
|
}
|
|
|
|
// EqualFold will use LOWER(column_name) = LOWER(val) for comparision
|
|
func (f Field) EqFold(val string) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " = LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
}
|
|
|
|
// Eq is equal
|
|
func (f Field) Eq(val any) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " = $", len: len(col) + 5}
|
|
}
|
|
|
|
func (f Field) NotEq(val any) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " != $", len: len(col) + 5}
|
|
}
|
|
|
|
func (f Field) Gt(val any) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " > $", len: len(col) + 5}
|
|
}
|
|
|
|
func (f Field) Lt(val any) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " < $", len: len(col) + 5}
|
|
}
|
|
|
|
func (f Field) Gte(val any) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " >= $", len: len(col) + 5}
|
|
}
|
|
|
|
func (f Field) Lte(val any) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " <= $", len: len(col) + 5}
|
|
}
|
|
|
|
func (f Field) Like(val string) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " LIKE $", len: len(f.String()) + 5}
|
|
}
|
|
|
|
func (f Field) LikeFold(val string) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " LIKE LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
}
|
|
|
|
// ILIKE is case-insensitive
|
|
func (f Field) ILike(val string) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " ILIKE $", len: len(col) + 5}
|
|
}
|
|
|
|
func (f Field) Any(val ...any) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " = ANY($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
}
|
|
|
|
func (f Field) NotAny(val ...any) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: val, op: " != ANY($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
}
|
|
|
|
// NotInSubQuery using ANY
|
|
func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery}
|
|
}
|
|
|
|
func (f Field) TsQuery(as string) Conditioner {
|
|
col := f.String()
|
|
return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5}
|
|
}
|
|
|
|
func joinFileds(fields []Field) string {
|
|
sb := getSB()
|
|
defer putSB(sb)
|
|
for i, f := range fields {
|
|
if i == 0 {
|
|
sb.WriteString(f.String())
|
|
} else {
|
|
sb.WriteString(", ")
|
|
sb.WriteString(f.String())
|
|
}
|
|
}
|
|
|
|
return sb.String()
|
|
}
|