Files
pgm/pgm_field.go

341 lines
9.3 KiB
Go
Raw Normal View History

package pgm
2025-11-16 11:37:02 +05:30
import (
"fmt"
"regexp"
"strings"
)
// Field holds the SQL text of a column reference or expression, for example
// "table.column" or "COUNT(table.column)". Its methods wrap or extend that
// text to build further expressions and query conditions.
type Field string
2025-11-16 11:37:02 +05:30
// sqlIdentifierRegex matches a plain SQL identifier: a letter or underscore
// followed by any number of letters, digits or underscores.
var sqlIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

// validDateTruncLevels is the set of precision levels accepted by DateTrunc.
// Keys are lower-case; looking up an unknown level yields false.
var validDateTruncLevels = map[string]bool{
	// sub-second and clock units
	"microseconds": true,
	"milliseconds": true,
	"second":       true,
	"minute":       true,
	"hour":         true,
	// calendar units
	"day":     true,
	"week":    true,
	"month":   true,
	"quarter": true,
	"year":    true,
	// long spans
	"decade":     true,
	"century":    true,
	"millennium": true,
}
// validateSQLIdentifier checks that s can be used verbatim as a SQL
// identifier. It returns nil when s matches sqlIdentifierRegex and a
// descriptive error otherwise.
func validateSQLIdentifier(s string) error {
	if sqlIdentifierRegex.MatchString(s) {
		return nil
	}
	return fmt.Errorf("invalid SQL identifier: %q (must be alphanumeric and underscore only)", s)
}
// escapeSQLString returns s with every single quote doubled, so the result
// can be placed between single quotes in a SQL string literal.
func escapeSQLString(s string) string {
	const (
		quote   = "'"
		escaped = "''"
	)
	// A negative count replaces every occurrence.
	return strings.Replace(s, quote, escaped, -1)
}
func (f Field) Name() string {
2025-11-16 16:21:35 +05:30
s := string(f)
idx := strings.LastIndexByte(s, '.')
if idx == -1 {
return s // Return as-is if no dot
}
return s[idx+1:] // Return part after last dot
}
// String returns the field's SQL text unchanged.
func (f Field) String() string {
	return string(f)
}
2025-10-18 14:43:42 +05:30
// Count wraps the field in the COUNT aggregate: COUNT(field).
func (f Field) Count() Field {
	return "COUNT(" + f + ")"
}
2025-11-16 11:37:02 +05:30
// ConcatWs creates a CONCAT_WS SQL function.
// SECURITY: The sep parameter should only be a constant string, not user input.
// Single quotes in sep will be escaped automatically.
func ConcatWs(sep string, fields ...Field) Field {
2025-11-16 11:37:02 +05:30
escapedSep := escapeSQLString(sep)
return Field("concat_ws('" + escapedSep + "'," + joinFileds(fields) + ")")
}
2025-11-16 11:37:02 +05:30
// StringAgg creates a STRING_AGG SQL function.
2025-11-16 16:21:35 +05:30
// SECURITY: The field parameter provides type safety by accepting only Field type.
2025-11-16 11:37:02 +05:30
// The sep parameter should only be a constant string. Single quotes will be escaped.
2025-11-16 16:21:35 +05:30
func StringAgg(field Field, sep string) Field {
2025-11-16 11:37:02 +05:30
escapedSep := escapeSQLString(sep)
2025-11-16 16:21:35 +05:30
return Field("string_agg(" + field.String() + ",'" + escapedSep + "')")
}
2025-11-16 11:37:02 +05:30
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
2025-11-16 16:21:35 +05:30
// SECURITY: The field parameter provides type safety by accepting only Field type.
2025-11-16 11:37:02 +05:30
// The sep parameter should only be a constant string. Single quotes will be escaped.
2025-11-16 16:21:35 +05:30
func StringAggCast(field Field, sep string) Field {
2025-11-16 11:37:02 +05:30
escapedSep := escapeSQLString(sep)
2025-11-16 16:21:35 +05:30
return Field("string_agg(cast(" + field.String() + " as varchar),'" + escapedSep + "')")
}
2025-10-18 14:43:42 +05:30
// StringEscape replaces a NULL value with the empty string by wrapping the
// field as:
//
//	COALESCE(field, '')
func (f Field) StringEscape() Field {
	return "COALESCE(" + f + ", '')"
}
2025-10-18 14:43:42 +05:30
// NumberEscape replaces a NULL value with zero by wrapping the field as:
//
//	COALESCE(field, 0)
func (f Field) NumberEscape() Field {
	return "COALESCE(" + f + ", 0)"
}
2025-10-18 14:43:42 +05:30
// BooleanEscape replaces a NULL value with FALSE by wrapping the field as:
//
//	COALESCE(field, FALSE)
func (f Field) BooleanEscape() Field {
	return "COALESCE(" + f + ", FALSE)"
}
2025-10-18 14:43:42 +05:30
// Avg wraps the field in the AVG aggregate: AVG(field).
func (f Field) Avg() Field {
	return "AVG(" + f + ")"
}
2025-10-18 14:43:42 +05:30
// Sum wraps the field in the SUM aggregate: SUM(field).
func (f Field) Sum() Field {
	return "SUM(" + f + ")"
}
2025-10-18 14:43:42 +05:30
// Max wraps the field in the MAX aggregate: MAX(field).
func (f Field) Max() Field {
	return "MAX(" + f + ")"
}
2025-10-18 14:43:42 +05:30
// Min wraps the field in the MIN aggregate: MIN(field).
//
// The function name is emitted in upper case to match the other wrappers in
// this file (COUNT, AVG, SUM, MAX, ...). SQL function names are
// case-insensitive, so the generated query behaves the same as before.
func (f Field) Min() Field {
	return Field("MIN(" + f.String() + ")")
}
2025-10-18 14:43:42 +05:30
// Lower wraps the field in the LOWER function: LOWER(field).
func (f Field) Lower() Field {
	return "LOWER(" + f + ")"
}
2025-10-18 14:43:42 +05:30
// Upper wraps the field in the UPPER function: UPPER(field).
func (f Field) Upper() Field {
	return "UPPER(" + f + ")"
}
2025-10-18 14:43:42 +05:30
// Trim wraps the field in the TRIM function: TRIM(field).
func (f Field) Trim() Field {
	return "TRIM(" + f + ")"
}
2025-10-18 14:43:42 +05:30
// Asc appends " ASC" to the field; intended for use in ORDER BY.
func (f Field) Asc() Field {
	return f + " ASC"
}
2025-10-18 14:43:42 +05:30
// Desc appends " DESC" to the field; intended for use in ORDER BY.
func (f Field) Desc() Field {
	return f + " DESC"
}
func (f Field) RowNumber(as string, extraOrderBy ...Field) Field {
return rowNumber(&f, nil, true, as, extraOrderBy...)
2025-10-21 00:27:00 +05:30
}
func (f Field) RowNumberDesc(as string, extraOrderBy ...Field) Field {
return rowNumber(&f, nil, false, as, extraOrderBy...)
2025-10-21 00:27:00 +05:30
}
// RowNumberPartionBy in ascending order
func (f Field) RowNumberPartionBy(partition Field, as string, extraOrderBy ...Field) Field {
return rowNumber(&f, &partition, true, as, extraOrderBy...)
2025-10-21 00:27:00 +05:30
}
func (f Field) RowNumberDescPartionBy(partition Field, as string, extraOrderBy ...Field) Field {
return rowNumber(&f, &partition, false, as, extraOrderBy...)
2025-10-21 00:27:00 +05:30
}
func rowNumber(f, partition *Field, isAsc bool, as string, extraOrderBy ...Field) Field {
2025-11-16 11:37:02 +05:30
// Validate as parameter is a valid SQL identifier
if as != "" {
if err := validateSQLIdentifier(as); err != nil {
panic(fmt.Sprintf("invalid AS alias in rowNumber: %v", err))
}
}
2025-10-21 00:27:00 +05:30
var orderBy string
if isAsc {
orderBy = " ASC"
} else {
orderBy = " DESC"
}
if as == "" {
as = "row_number"
}
col := f.String()
// Build ORDER BY clause with primary field and extra fields
sb := getSB()
defer putSB(sb)
sb.WriteString(col)
sb.WriteString(orderBy)
// Add extra ORDER BY fields
for _, extra := range extraOrderBy {
sb.WriteString(", ")
sb.WriteString(extra.String())
}
orderByClause := sb.String()
2025-10-21 00:27:00 +05:30
if partition != nil {
return Field("ROW_NUMBER() OVER (PARTITION BY " + partition.String() + " ORDER BY " + orderByClause + ") AS " + as)
2025-10-21 00:27:00 +05:30
}
return Field("ROW_NUMBER() OVER (ORDER BY " + orderByClause + ") AS " + as)
2025-10-21 00:27:00 +05:30
}
// IsNull returns a condition of the form "field IS NULL". It carries no bound
// value; len is set to the combined length of the column and operator text.
func (f Field) IsNull() Conditioner {
	const op = " IS NULL"
	name := f.String()
	return &Cond{Field: name, op: op, len: len(name) + len(op)}
}
// IsNotNull returns a condition of the form "field IS NOT NULL". It carries no
// bound value; len is set to the combined length of the column and operator
// text.
func (f Field) IsNotNull() Conditioner {
	const op = " IS NOT NULL"
	name := f.String()
	return &Cond{Field: name, op: op, len: len(name) + len(op)}
}
2025-10-21 16:45:43 +05:30
// DateTrunc will truncate date or timestamp to specified level of precision
//
// Level values:
// - microseconds, milliseconds, second, minute, hour
// - day, week (Monday start), month, quarter, year
// - decade, century, millennium
func (f Field) DateTrunc(level, as string) Field {
2025-11-16 11:37:02 +05:30
// Validate level parameter against allowed values
if !validDateTruncLevels[strings.ToLower(level)] {
panic(fmt.Sprintf("invalid DATE_TRUNC level: %q (allowed: microseconds, milliseconds, second, minute, hour, day, week, month, quarter, year, decade, century, millennium)", level))
}
// Validate as parameter is a valid SQL identifier
if err := validateSQLIdentifier(as); err != nil {
panic(fmt.Sprintf("invalid AS alias in DateTrunc: %v", err))
}
return Field("DATE_TRUNC('" + strings.ToLower(level) + "', " + f.String() + ") AS " + as)
2025-10-21 16:45:43 +05:30
}
2025-11-08 14:29:54 +05:30
func (f Field) TsRank(fieldName, as string) Field {
2025-11-16 11:37:02 +05:30
// Validate as parameter is a valid SQL identifier
if err := validateSQLIdentifier(as); err != nil {
panic(fmt.Sprintf("invalid AS alias in TsRank: %v", err))
}
2025-11-08 14:29:54 +05:30
return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as)
}
// EqFold returns a case-insensitive equality condition: the column is wrapped
// in LOWER(...) and compared against LOWER of the bound value val.
// NOTE(review): CondActionNeedToClose presumably makes Cond emit the closing
// parenthesis after the placeholder — confirm in Cond.
func (f Field) EqFold(val string) Conditioner {
	name := f.String()
	cond := Cond{
		Field:  "LOWER(" + name + ")",
		Val:    val,
		op:     " = LOWER($",
		action: CondActionNeedToClose,
		len:    len(name) + 5,
	}
	return &cond
}
// Eq returns the condition "field = $n" with val as the bound value.
func (f Field) Eq(val any) Conditioner {
	name := f.String()
	cond := Cond{Field: name, Val: val, op: " = $", len: len(name) + 5}
	return &cond
}
// NotEq returns the condition "field != $n" with val as the bound value.
func (f Field) NotEq(val any) Conditioner {
	name := f.String()
	cond := Cond{Field: name, Val: val, op: " != $", len: len(name) + 5}
	return &cond
}
// Gt returns the condition "field > $n" with val as the bound value.
func (f Field) Gt(val any) Conditioner {
	name := f.String()
	cond := Cond{Field: name, Val: val, op: " > $", len: len(name) + 5}
	return &cond
}
// Lt returns the condition "field < $n" with val as the bound value.
func (f Field) Lt(val any) Conditioner {
	name := f.String()
	cond := Cond{Field: name, Val: val, op: " < $", len: len(name) + 5}
	return &cond
}
// Gte returns the condition "field >= $n" with val as the bound value.
func (f Field) Gte(val any) Conditioner {
	name := f.String()
	cond := Cond{Field: name, Val: val, op: " >= $", len: len(name) + 5}
	return &cond
}
// Lte returns the condition "field <= $n" with val as the bound value.
func (f Field) Lte(val any) Conditioner {
	name := f.String()
	cond := Cond{Field: name, Val: val, op: " <= $", len: len(name) + 5}
	return &cond
}
// Like returns the condition "field LIKE $n" with val as the bound pattern.
func (f Field) Like(val string) Conditioner {
	name := f.String()
	cond := Cond{Field: name, Val: val, op: " LIKE $", len: len(name) + 5}
	return &cond
}
// LikeFold returns a case-insensitive LIKE condition: the column is wrapped in
// LOWER(...) and matched against LOWER of the bound pattern val.
// NOTE(review): CondActionNeedToClose presumably makes Cond emit the closing
// parenthesis after the placeholder — confirm in Cond.
func (f Field) LikeFold(val string) Conditioner {
	name := f.String()
	cond := Cond{
		Field:  "LOWER(" + name + ")",
		Val:    val,
		op:     " LIKE LOWER($",
		action: CondActionNeedToClose,
		len:    len(name) + 5,
	}
	return &cond
}
// ILike returns the condition "field ILIKE $n", a case-insensitive pattern
// match, with val as the bound pattern.
func (f Field) ILike(val string) Conditioner {
	name := f.String()
	cond := Cond{Field: name, Val: val, op: " ILIKE $", len: len(name) + 5}
	return &cond
}
2025-10-18 14:43:42 +05:30
// Any returns the condition "field = ANY($n" with the variadic values bound as
// a single slice value.
// NOTE(review): CondActionNeedToClose presumably makes Cond emit the closing
// parenthesis after the placeholder — confirm in Cond.
func (f Field) Any(val ...any) Conditioner {
	name := f.String()
	cond := Cond{
		Field:  name,
		Val:    val,
		op:     " = ANY($",
		action: CondActionNeedToClose,
		len:    len(name) + 5,
	}
	return &cond
}
2025-10-18 14:43:42 +05:30
// NotAny returns the condition "field != ANY($n" with the variadic values
// bound as a single slice value.
// NOTE(review): in SQL, x != ANY(list) is true as soon as x differs from at
// least one element; "not in the list" would be x != ALL(list). Confirm which
// meaning callers expect.
// NOTE(review): CondActionNeedToClose presumably makes Cond emit the closing
// parenthesis after the placeholder — confirm in Cond.
func (f Field) NotAny(val ...any) Conditioner {
	name := f.String()
	cond := Cond{
		Field:  name,
		Val:    val,
		op:     " != ANY($",
		action: CondActionNeedToClose,
		len:    len(name) + 5,
	}
	return &cond
}
// NotInSubQuery returns a condition that holds when the field's value is not
// among the rows returned by the sub-query qry.
//
// The operator is "!= ALL(...)". The earlier "!= ANY(...)" form is true as soon
// as the value differs from at least one sub-query row, which is not NOT IN;
// in PostgreSQL NOT IN is equivalent to <> ALL. The "($)" placeholder shape of
// the operator is kept unchanged.
// NOTE(review): how Cond expands the placeholder for CondActionSubQuery is not
// visible here — confirm it does not depend on the literal "ANY".
func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
	col := f.String()
	return &Cond{Field: col, Val: qry, op: " != ALL($)", action: CondActionSubQuery}
}
func (f Field) TsQuery(as string) Conditioner {
col := f.String()
return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5}
2025-10-18 14:43:42 +05:30
}
// joinFileds concatenates the SQL text of fields, separated by ", ". It
// returns the empty string when fields is empty. The text is assembled in a
// builder obtained from getSB and handed back with putSB on return.
func joinFileds(fields []Field) string {
	sb := getSB()
	defer putSB(sb)

	for i := range fields {
		// Every element except the first is preceded by the separator.
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(fields[i].String())
	}
	return sb.String()
}