Audit with AI

This commit is contained in:
2025-11-16 11:37:02 +05:30
parent 551e2123bc
commit 29cddb6389
24 changed files with 1286 additions and 79 deletions

338
pgm_field_test.go Normal file
View File

@@ -0,0 +1,338 @@
package pgm
import (
"strings"
"testing"
)
// TestValidateSQLIdentifier tests SQL identifier validation
func TestValidateSQLIdentifier(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{"valid simple", "column_name", false},
{"valid with numbers", "column123", false},
{"valid underscore", "_private", false},
{"valid mixed", "my_Column_123", false},
{"invalid with space", "column name", true},
{"invalid with dash", "column-name", true},
{"invalid with dot", "table.column", true},
{"invalid with quote", "column'name", true},
{"invalid with semicolon", "column;DROP", true},
{"invalid starts with number", "123column", true},
{"invalid SQL keyword injection", "col); DROP TABLE users; --", true},
{"empty string", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateSQLIdentifier(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("validateSQLIdentifier(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
})
}
}
// TestEscapeSQLString tests SQL string escaping
func TestEscapeSQLString(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"no quotes", "hello", "hello"},
{"single quote", "hello'world", "hello''world"},
{"multiple quotes", "it's a 'test'", "it''s a ''test''"},
{"only quote", "'", "''"},
{"empty string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := escapeSQLString(tt.input)
if result != tt.expected {
t.Errorf("escapeSQLString(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
// TestConcatWsSQLInjectionPrevention tests that ConcatWs escapes quotes
func TestConcatWsSQLInjectionPrevention(t *testing.T) {
tests := []struct {
name string
sep string
fields []Field
contains string
}{
{
name: "safe separator",
sep: ", ",
fields: []Field{"col1", "col2"},
contains: "concat_ws(', ',col1, col2)",
},
{
name: "escaped quotes",
sep: "', (SELECT password FROM users), '",
fields: []Field{"col1"},
contains: "concat_ws(''', (SELECT password FROM users), ''',col1)",
},
{
name: "single quote",
sep: "'",
fields: []Field{"col1"},
contains: "concat_ws('''',col1)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConcatWs(tt.sep, tt.fields...)
if !strings.Contains(string(result), tt.contains) {
t.Errorf("ConcatWs(%q, %v) = %q, should contain %q", tt.sep, tt.fields, result, tt.contains)
}
})
}
}
// TestStringAggSQLInjectionPrevention tests that StringAgg escapes quotes
func TestStringAggSQLInjectionPrevention(t *testing.T) {
tests := []struct {
name string
exp string
sep string
contains string
}{
{
name: "safe parameters",
exp: "column_name",
sep: ", ",
contains: "string_agg(column_name,', ')",
},
{
name: "escaped quotes in separator",
sep: "'; DROP TABLE users; --",
exp: "col",
contains: "string_agg(col,'''; DROP TABLE users; --')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := StringAgg(tt.exp, tt.sep)
if !strings.Contains(string(result), tt.contains) {
t.Errorf("StringAgg(%q, %q) = %q, should contain %q", tt.exp, tt.sep, result, tt.contains)
}
})
}
}
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
result := StringAggCast("column", "'; DROP TABLE")
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
if string(result) != expected {
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
}
}
// TestDateTruncValidation tests DateTrunc level validation
func TestDateTruncValidation(t *testing.T) {
field := Field("created_at")
validLevels := []string{
"microseconds", "milliseconds", "second", "minute", "hour",
"day", "week", "month", "quarter", "year", "decade", "century", "millennium",
}
for _, level := range validLevels {
t.Run("valid_"+level, func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("DateTrunc(%q) should not panic for valid level", level)
}
}()
result := field.DateTrunc(level, "truncated")
if !strings.Contains(string(result), level) {
t.Errorf("DateTrunc result should contain level %q", level)
}
})
}
// Test case-insensitive
t.Run("case_insensitive", func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("DateTrunc should accept uppercase level")
}
}()
result := field.DateTrunc("MONTH", "truncated")
if !strings.Contains(strings.ToLower(string(result)), "month") {
t.Errorf("DateTrunc should normalize case")
}
})
}
// TestDateTruncInvalidLevel tests that DateTrunc panics on invalid level
func TestDateTruncInvalidLevel(t *testing.T) {
field := Field("created_at")
invalidLevels := []string{
"invalid",
"'; DROP TABLE users; --",
"day; DELETE FROM",
"",
}
for _, level := range invalidLevels {
t.Run("invalid_"+level, func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("DateTrunc(%q, 'alias') should panic for invalid level", level)
}
}()
field.DateTrunc(level, "truncated")
})
}
}
// TestDateTruncInvalidAlias tests that DateTrunc panics on invalid alias
func TestDateTruncInvalidAlias(t *testing.T) {
field := Field("created_at")
invalidAliases := []string{
"alias name",
"alias-name",
"'; DROP TABLE",
"123alias",
"alias.name",
}
for _, alias := range invalidAliases {
t.Run("invalid_alias_"+alias, func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("DateTrunc('day', %q) should panic for invalid alias", alias)
}
}()
field.DateTrunc("day", alias)
})
}
}
// TestTsRankValidation tests TsRank alias validation
func TestTsRankValidation(t *testing.T) {
field := Field("search_vector")
validAliases := []string{"rank", "score", "relevance_123", "_rank"}
for _, alias := range validAliases {
t.Run("valid_"+alias, func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("TsRank should not panic for valid alias %q", alias)
}
}()
result := field.TsRank("query", alias)
if !strings.Contains(string(result), alias) {
t.Errorf("TsRank result should contain alias %q", alias)
}
})
}
invalidAliases := []string{
"'; DROP TABLE",
"rank name",
"rank-name",
"123rank",
}
for _, alias := range invalidAliases {
t.Run("invalid_"+alias, func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("TsRank should panic for invalid alias %q", alias)
}
}()
field.TsRank("query", alias)
})
}
}
// TestRowNumberValidation tests ROW_NUMBER alias validation
func TestRowNumberValidation(t *testing.T) {
field := Field("id")
t.Run("valid_alias", func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("RowNumber should not panic for valid alias: %v", r)
}
}()
result := field.RowNumber("row_num")
if !strings.Contains(string(result), "row_num") {
t.Errorf("RowNumber result should contain alias")
}
})
t.Run("invalid_alias", func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("RowNumber should panic for invalid alias")
}
}()
field.RowNumber("'; DROP TABLE")
})
}
// TestSQLInjectionAttackVectors tests common SQL injection patterns
func TestSQLInjectionAttackVectors(t *testing.T) {
attacks := []string{
"'; DROP TABLE users; --",
"' OR '1'='1",
"'; DELETE FROM users WHERE '1'='1",
"1'; UPDATE users SET password='hacked'; --",
"admin'--",
"' UNION SELECT * FROM passwords--",
}
t.Run("DateTrunc_alias_protection", func(t *testing.T) {
field := Field("created_at")
for _, attack := range attacks {
func() {
defer func() {
if r := recover(); r == nil {
t.Errorf("DateTrunc should prevent SQL injection: %q", attack)
}
}()
field.DateTrunc("day", attack)
}()
}
})
t.Run("TsRank_alias_protection", func(t *testing.T) {
field := Field("search_vector")
for _, attack := range attacks {
func() {
defer func() {
if r := recover(); r == nil {
t.Errorf("TsRank should prevent SQL injection: %q", attack)
}
}()
field.TsRank("query", attack)
}()
}
})
t.Run("ConcatWs_separator_escaping", func(t *testing.T) {
for _, attack := range attacks {
result := ConcatWs(attack, Field("col1"))
// Check that quotes are escaped (doubled)
if strings.Contains(attack, "'") && !strings.Contains(string(result), "''") {
t.Errorf("ConcatWs should escape quotes in attack: %q", attack)
}
}
})
}