package pgm import ( "fmt" "regexp" "strings" ) // Field related to a table type Field string var ( // sqlIdentifierRegex validates SQL identifiers (alphanumeric and underscore only) sqlIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) // validDateTruncLevels contains all allowed DATE_TRUNC precision levels validDateTruncLevels = map[string]bool{ "microseconds": true, "milliseconds": true, "second": true, "minute": true, "hour": true, "day": true, "week": true, "month": true, "quarter": true, "year": true, "decade": true, "century": true, "millennium": true, } ) // validateSQLIdentifier checks if a string is a valid SQL identifier func validateSQLIdentifier(s string) error { if !sqlIdentifierRegex.MatchString(s) { return fmt.Errorf("invalid SQL identifier: %q (must be alphanumeric and underscore only)", s) } return nil } // escapeSQLString escapes single quotes in a string for SQL func escapeSQLString(s string) string { return strings.ReplaceAll(s, "'", "''") } func (f Field) Name() string { return strings.Split(string(f), ".")[1] } func (f Field) String() string { return string(f) } // Count function wrapped field func (f Field) Count() Field { return Field("COUNT(" + f.String() + ")") } // ConcatWs creates a CONCAT_WS SQL function. // SECURITY: The sep parameter should only be a constant string, not user input. // Single quotes in sep will be escaped automatically. func ConcatWs(sep string, fields ...Field) Field { escapedSep := escapeSQLString(sep) return Field("concat_ws('" + escapedSep + "'," + joinFileds(fields) + ")") } // StringAgg creates a STRING_AGG SQL function. // SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL. // The sep parameter should only be a constant string. Single quotes will be escaped. func StringAgg(exp, sep string) Field { escapedSep := escapeSQLString(sep) return Field("string_agg(" + exp + ",'" + escapedSep + "')") } // StringAggCast creates a STRING_AGG SQL function with cast to varchar. // SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL. // The sep parameter should only be a constant string. Single quotes will be escaped. func StringAggCast(exp, sep string) Field { escapedSep := escapeSQLString(sep) return Field("string_agg(cast(" + exp + " as varchar),'" + escapedSep + "')") } // StringEscape will wrap field with: // // COALESCE(field, ”) func (f Field) StringEscape() Field { return Field("COALESCE(" + f.String() + ", '')") } // NumberEscape will wrap field with: // // COALESCE(field, 0) func (f Field) NumberEscape() Field { return Field("COALESCE(" + f.String() + ", 0)") } // BooleanEscape will wrap field with: // // COALESCE(field, FALSE) func (f Field) BooleanEscape() Field { return Field("COALESCE(" + f.String() + ", FALSE)") } // Avg function wrapped field func (f Field) Avg() Field { return Field("AVG(" + f.String() + ")") } // Sum function wrapped field func (f Field) Sum() Field { return Field("SUM(" + f.String() + ")") } // Max function wrapped field func (f Field) Max() Field { return Field("MAX(" + f.String() + ")") } // Min function wrapped field func (f Field) Min() Field { return Field("Min(" + f.String() + ")") } // Lower function wrapped field func (f Field) Lower() Field { return Field("LOWER(" + f.String() + ")") } // Upper function wrapped field func (f Field) Upper() Field { return Field("UPPER(" + f.String() + ")") } // Trim function wrapped field func (f Field) Trim() Field { return Field("TRIM(" + f.String() + ")") } // Asc suffixed field, supposed to be used with order by func (f Field) Asc() Field { return Field(f.String() + " ASC") } // Desc suffixed field, supposed to be used with order by func (f Field) Desc() Field { return Field(f.String() + " DESC") } func (f Field) RowNumber(as string) Field { return rowNumber(&f, nil, true, as) } func (f Field) RowNumberDesc(as string) Field { return rowNumber(&f, nil, true, as) } // RowNumberPartionBy in ascending order func (f Field) RowNumberPartionBy(partition Field, as string) Field { return rowNumber(&f, &partition, true, as) } func (f Field) RowNumberDescPartionBy(partition Field, as string) Field { return rowNumber(&f, &partition, false, as) } func rowNumber(f, partition *Field, isAsc bool, as string) Field { // Validate as parameter is a valid SQL identifier if as != "" { if err := validateSQLIdentifier(as); err != nil { panic(fmt.Sprintf("invalid AS alias in rowNumber: %v", err)) } } var orderBy string if isAsc { orderBy = " ASC" } else { orderBy = " DESC" } if as == "" { as = "row_number" } col := f.String() if partition != nil { return Field("ROW_NUMBER() OVER (PARTITION BY " + partition.String() + " ORDER BY " + col + orderBy + ") AS " + as) } return Field("ROW_NUMBER() OVER (ORDER BY " + col + orderBy + ") AS " + as) } func (f Field) IsNull() Conditioner { col := f.String() return &Cond{Field: col, op: " IS NULL", len: len(col) + 8} } func (f Field) IsNotNull() Conditioner { col := f.String() return &Cond{Field: col, op: " IS NOT NULL", len: len(col) + 12} } // DateTrunc will truncate date or timestamp to specified level of precision // // Level values: // - microseconds, milliseconds, second, minute, hour // - day, week (Monday start), month, quarter, year // - decade, century, millennium func (f Field) DateTrunc(level, as string) Field { // Validate level parameter against allowed values if !validDateTruncLevels[strings.ToLower(level)] { panic(fmt.Sprintf("invalid DATE_TRUNC level: %q (allowed: microseconds, milliseconds, second, minute, hour, day, week, month, quarter, year, decade, century, millennium)", level)) } // Validate as parameter is a valid SQL identifier if err := validateSQLIdentifier(as); err != nil { panic(fmt.Sprintf("invalid AS alias in DateTrunc: %v", err)) } return Field("DATE_TRUNC('" + strings.ToLower(level) + "', " + f.String() + ") AS " + as) } func (f Field) TsRank(fieldName, as string) Field { // Validate as parameter is a valid SQL identifier if err := validateSQLIdentifier(as); err != nil { panic(fmt.Sprintf("invalid AS alias in TsRank: %v", err)) } return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as) } // EqualFold will use LOWER(column_name) = LOWER(val) for comparision func (f Field) EqFold(val string) Conditioner { col := f.String() return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " = LOWER($", action: CondActionNeedToClose, len: len(col) + 5} } // Eq is equal func (f Field) Eq(val any) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " = $", len: len(col) + 5} } func (f Field) NotEq(val any) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " != $", len: len(col) + 5} } func (f Field) Gt(val any) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " > $", len: len(col) + 5} } func (f Field) Lt(val any) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " < $", len: len(col) + 5} } func (f Field) Gte(val any) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " >= $", len: len(col) + 5} } func (f Field) Lte(val any) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " <= $", len: len(col) + 5} } func (f Field) Like(val string) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " LIKE $", len: len(f.String()) + 5} } func (f Field) LikeFold(val string) Conditioner { col := f.String() return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " LIKE LOWER($", action: CondActionNeedToClose, len: len(col) + 5} } // ILIKE is case-insensitive func (f Field) ILike(val string) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " ILIKE $", len: len(col) + 5} } func (f Field) Any(val ...any) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " = ANY($", action: CondActionNeedToClose, len: len(col) + 5} } func (f Field) NotAny(val ...any) Conditioner { col := f.String() return &Cond{Field: col, Val: val, op: " != ANY($", action: CondActionNeedToClose, len: len(col) + 5} } // NotInSubQuery using ANY func (f Field) NotInSubQuery(qry WhereClause) Conditioner { col := f.String() return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery} } func (f Field) TsQuery(as string) Conditioner { col := f.String() return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5} } func joinFileds(fields []Field) string { sb := getSB() defer putSB(sb) for i, f := range fields { if i == 0 { sb.WriteString(f.String()) } else { sb.WriteString(", ") sb.WriteString(f.String()) } } return sb.String() }