Audit with AI

This commit is contained in:
2025-11-16 11:37:02 +05:30
parent 551e2123bc
commit 29cddb6389
24 changed files with 1286 additions and 79 deletions

View File

@@ -1,10 +1,49 @@
package pgm
import "strings"
import (
"fmt"
"regexp"
"strings"
)
// Field related to a table
type Field string
var (
// sqlIdentifierRegex validates SQL identifiers (alphanumeric and underscore only)
sqlIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// validDateTruncLevels contains all allowed DATE_TRUNC precision levels
validDateTruncLevels = map[string]bool{
"microseconds": true,
"milliseconds": true,
"second": true,
"minute": true,
"hour": true,
"day": true,
"week": true,
"month": true,
"quarter": true,
"year": true,
"decade": true,
"century": true,
"millennium": true,
}
)
// validateSQLIdentifier checks if a string is a valid SQL identifier
func validateSQLIdentifier(s string) error {
if !sqlIdentifierRegex.MatchString(s) {
return fmt.Errorf("invalid SQL identifier: %q (must be alphanumeric and underscore only)", s)
}
return nil
}
// escapeSQLString escapes single quotes in a string for SQL
func escapeSQLString(s string) string {
return strings.ReplaceAll(s, "'", "''")
}
func (f Field) Name() string {
return strings.Split(string(f), ".")[1]
}
@@ -18,16 +57,28 @@ func (f Field) Count() Field {
return Field("COUNT(" + f.String() + ")")
}
// ConcatWs creates a CONCAT_WS SQL function.
// SECURITY: The sep parameter should only be a constant string, not user input.
// Single quotes in sep will be escaped automatically.
func ConcatWs(sep string, fields ...Field) Field {
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
escapedSep := escapeSQLString(sep)
return Field("concat_ws('" + escapedSep + "'," + joinFileds(fields) + ")")
}
// StringAgg creates a STRING_AGG SQL function.
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
// The sep parameter should only be a constant string. Single quotes will be escaped.
func StringAgg(exp, sep string) Field {
return Field("string_agg(" + exp + ",'" + sep + "')")
escapedSep := escapeSQLString(sep)
return Field("string_agg(" + exp + ",'" + escapedSep + "')")
}
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
// The sep parameter should only be a constant string. Single quotes will be escaped.
func StringAggCast(exp, sep string) Field {
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
escapedSep := escapeSQLString(sep)
return Field("string_agg(cast(" + exp + " as varchar),'" + escapedSep + "')")
}
// StringEscape will wrap field with:
@@ -114,6 +165,12 @@ func (f Field) RowNumberDescPartionBy(partition Field, as string) Field {
}
func rowNumber(f, partition *Field, isAsc bool, as string) Field {
// Validate as parameter is a valid SQL identifier
if as != "" {
if err := validateSQLIdentifier(as); err != nil {
panic(fmt.Sprintf("invalid AS alias in rowNumber: %v", err))
}
}
var orderBy string
if isAsc {
orderBy = " ASC"
@@ -150,10 +207,25 @@ func (f Field) IsNotNull() Conditioner {
// - day, week (Monday start), month, quarter, year
// - decade, century, millennium
func (f Field) DateTrunc(level, as string) Field {
return Field("DATE_TRUNC('" + level + "', " + f.String() + ") AS " + as)
// Validate level parameter against allowed values
if !validDateTruncLevels[strings.ToLower(level)] {
panic(fmt.Sprintf("invalid DATE_TRUNC level: %q (allowed: microseconds, milliseconds, second, minute, hour, day, week, month, quarter, year, decade, century, millennium)", level))
}
// Validate as parameter is a valid SQL identifier
if err := validateSQLIdentifier(as); err != nil {
panic(fmt.Sprintf("invalid AS alias in DateTrunc: %v", err))
}
return Field("DATE_TRUNC('" + strings.ToLower(level) + "', " + f.String() + ") AS " + as)
}
func (f Field) TsRank(fieldName, as string) Field {
// Validate as parameter is a valid SQL identifier
if err := validateSQLIdentifier(as); err != nil {
panic(fmt.Sprintf("invalid AS alias in TsRank: %v", err))
}
return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as)
}