forked from go/pgm
Audit with AI
This commit is contained in:
82
pgm_field.go
82
pgm_field.go
@@ -1,10 +1,49 @@
|
||||
package pgm
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Field related to a table
|
||||
type Field string
|
||||
|
||||
var (
|
||||
// sqlIdentifierRegex validates SQL identifiers (alphanumeric and underscore only)
|
||||
sqlIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// validDateTruncLevels contains all allowed DATE_TRUNC precision levels
|
||||
validDateTruncLevels = map[string]bool{
|
||||
"microseconds": true,
|
||||
"milliseconds": true,
|
||||
"second": true,
|
||||
"minute": true,
|
||||
"hour": true,
|
||||
"day": true,
|
||||
"week": true,
|
||||
"month": true,
|
||||
"quarter": true,
|
||||
"year": true,
|
||||
"decade": true,
|
||||
"century": true,
|
||||
"millennium": true,
|
||||
}
|
||||
)
|
||||
|
||||
// validateSQLIdentifier checks if a string is a valid SQL identifier
|
||||
func validateSQLIdentifier(s string) error {
|
||||
if !sqlIdentifierRegex.MatchString(s) {
|
||||
return fmt.Errorf("invalid SQL identifier: %q (must be alphanumeric and underscore only)", s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// escapeSQLString escapes single quotes in a string for SQL
|
||||
func escapeSQLString(s string) string {
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
func (f Field) Name() string {
|
||||
return strings.Split(string(f), ".")[1]
|
||||
}
|
||||
@@ -18,16 +57,28 @@ func (f Field) Count() Field {
|
||||
return Field("COUNT(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// ConcatWs creates a CONCAT_WS SQL function.
|
||||
// SECURITY: The sep parameter should only be a constant string, not user input.
|
||||
// Single quotes in sep will be escaped automatically.
|
||||
func ConcatWs(sep string, fields ...Field) Field {
|
||||
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("concat_ws('" + escapedSep + "'," + joinFileds(fields) + ")")
|
||||
}
|
||||
|
||||
// StringAgg creates a STRING_AGG SQL function.
|
||||
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
|
||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||
func StringAgg(exp, sep string) Field {
|
||||
return Field("string_agg(" + exp + ",'" + sep + "')")
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("string_agg(" + exp + ",'" + escapedSep + "')")
|
||||
}
|
||||
|
||||
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
|
||||
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
|
||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||
func StringAggCast(exp, sep string) Field {
|
||||
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("string_agg(cast(" + exp + " as varchar),'" + escapedSep + "')")
|
||||
}
|
||||
|
||||
// StringEscape will wrap field with:
|
||||
@@ -114,6 +165,12 @@ func (f Field) RowNumberDescPartionBy(partition Field, as string) Field {
|
||||
}
|
||||
|
||||
func rowNumber(f, partition *Field, isAsc bool, as string) Field {
|
||||
// Validate as parameter is a valid SQL identifier
|
||||
if as != "" {
|
||||
if err := validateSQLIdentifier(as); err != nil {
|
||||
panic(fmt.Sprintf("invalid AS alias in rowNumber: %v", err))
|
||||
}
|
||||
}
|
||||
var orderBy string
|
||||
if isAsc {
|
||||
orderBy = " ASC"
|
||||
@@ -150,10 +207,25 @@ func (f Field) IsNotNull() Conditioner {
|
||||
// - day, week (Monday start), month, quarter, year
|
||||
// - decade, century, millennium
|
||||
func (f Field) DateTrunc(level, as string) Field {
|
||||
return Field("DATE_TRUNC('" + level + "', " + f.String() + ") AS " + as)
|
||||
// Validate level parameter against allowed values
|
||||
if !validDateTruncLevels[strings.ToLower(level)] {
|
||||
panic(fmt.Sprintf("invalid DATE_TRUNC level: %q (allowed: microseconds, milliseconds, second, minute, hour, day, week, month, quarter, year, decade, century, millennium)", level))
|
||||
}
|
||||
|
||||
// Validate as parameter is a valid SQL identifier
|
||||
if err := validateSQLIdentifier(as); err != nil {
|
||||
panic(fmt.Sprintf("invalid AS alias in DateTrunc: %v", err))
|
||||
}
|
||||
|
||||
return Field("DATE_TRUNC('" + strings.ToLower(level) + "', " + f.String() + ") AS " + as)
|
||||
}
|
||||
|
||||
func (f Field) TsRank(fieldName, as string) Field {
|
||||
// Validate as parameter is a valid SQL identifier
|
||||
if err := validateSQLIdentifier(as); err != nil {
|
||||
panic(fmt.Sprintf("invalid AS alias in TsRank: %v", err))
|
||||
}
|
||||
|
||||
return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user