Audit with AI
This commit is contained in:
338
pgm_field_test.go
Normal file
338
pgm_field_test.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package pgm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestValidateSQLIdentifier tests SQL identifier validation
|
||||
func TestValidateSQLIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid simple", "column_name", false},
|
||||
{"valid with numbers", "column123", false},
|
||||
{"valid underscore", "_private", false},
|
||||
{"valid mixed", "my_Column_123", false},
|
||||
{"invalid with space", "column name", true},
|
||||
{"invalid with dash", "column-name", true},
|
||||
{"invalid with dot", "table.column", true},
|
||||
{"invalid with quote", "column'name", true},
|
||||
{"invalid with semicolon", "column;DROP", true},
|
||||
{"invalid starts with number", "123column", true},
|
||||
{"invalid SQL keyword injection", "col); DROP TABLE users; --", true},
|
||||
{"empty string", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateSQLIdentifier(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateSQLIdentifier(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEscapeSQLString tests SQL string escaping
|
||||
func TestEscapeSQLString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"no quotes", "hello", "hello"},
|
||||
{"single quote", "hello'world", "hello''world"},
|
||||
{"multiple quotes", "it's a 'test'", "it''s a ''test''"},
|
||||
{"only quote", "'", "''"},
|
||||
{"empty string", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := escapeSQLString(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("escapeSQLString(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcatWsSQLInjectionPrevention tests that ConcatWs escapes quotes
|
||||
func TestConcatWsSQLInjectionPrevention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sep string
|
||||
fields []Field
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "safe separator",
|
||||
sep: ", ",
|
||||
fields: []Field{"col1", "col2"},
|
||||
contains: "concat_ws(', ',col1, col2)",
|
||||
},
|
||||
{
|
||||
name: "escaped quotes",
|
||||
sep: "', (SELECT password FROM users), '",
|
||||
fields: []Field{"col1"},
|
||||
contains: "concat_ws(''', (SELECT password FROM users), ''',col1)",
|
||||
},
|
||||
{
|
||||
name: "single quote",
|
||||
sep: "'",
|
||||
fields: []Field{"col1"},
|
||||
contains: "concat_ws('''',col1)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConcatWs(tt.sep, tt.fields...)
|
||||
if !strings.Contains(string(result), tt.contains) {
|
||||
t.Errorf("ConcatWs(%q, %v) = %q, should contain %q", tt.sep, tt.fields, result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStringAggSQLInjectionPrevention tests that StringAgg escapes quotes
|
||||
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exp string
|
||||
sep string
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "safe parameters",
|
||||
exp: "column_name",
|
||||
sep: ", ",
|
||||
contains: "string_agg(column_name,', ')",
|
||||
},
|
||||
{
|
||||
name: "escaped quotes in separator",
|
||||
sep: "'; DROP TABLE users; --",
|
||||
exp: "col",
|
||||
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := StringAgg(tt.exp, tt.sep)
|
||||
if !strings.Contains(string(result), tt.contains) {
|
||||
t.Errorf("StringAgg(%q, %q) = %q, should contain %q", tt.exp, tt.sep, result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
||||
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
||||
result := StringAggCast("column", "'; DROP TABLE")
|
||||
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
||||
if string(result) != expected {
|
||||
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDateTruncValidation tests DateTrunc level validation
|
||||
func TestDateTruncValidation(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
|
||||
validLevels := []string{
|
||||
"microseconds", "milliseconds", "second", "minute", "hour",
|
||||
"day", "week", "month", "quarter", "year", "decade", "century", "millennium",
|
||||
}
|
||||
|
||||
for _, level := range validLevels {
|
||||
t.Run("valid_"+level, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("DateTrunc(%q) should not panic for valid level", level)
|
||||
}
|
||||
}()
|
||||
result := field.DateTrunc(level, "truncated")
|
||||
if !strings.Contains(string(result), level) {
|
||||
t.Errorf("DateTrunc result should contain level %q", level)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test case-insensitive
|
||||
t.Run("case_insensitive", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("DateTrunc should accept uppercase level")
|
||||
}
|
||||
}()
|
||||
result := field.DateTrunc("MONTH", "truncated")
|
||||
if !strings.Contains(strings.ToLower(string(result)), "month") {
|
||||
t.Errorf("DateTrunc should normalize case")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDateTruncInvalidLevel tests that DateTrunc panics on invalid level
|
||||
func TestDateTruncInvalidLevel(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
|
||||
invalidLevels := []string{
|
||||
"invalid",
|
||||
"'; DROP TABLE users; --",
|
||||
"day; DELETE FROM",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, level := range invalidLevels {
|
||||
t.Run("invalid_"+level, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("DateTrunc(%q, 'alias') should panic for invalid level", level)
|
||||
}
|
||||
}()
|
||||
field.DateTrunc(level, "truncated")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDateTruncInvalidAlias tests that DateTrunc panics on invalid alias
|
||||
func TestDateTruncInvalidAlias(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
|
||||
invalidAliases := []string{
|
||||
"alias name",
|
||||
"alias-name",
|
||||
"'; DROP TABLE",
|
||||
"123alias",
|
||||
"alias.name",
|
||||
}
|
||||
|
||||
for _, alias := range invalidAliases {
|
||||
t.Run("invalid_alias_"+alias, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("DateTrunc('day', %q) should panic for invalid alias", alias)
|
||||
}
|
||||
}()
|
||||
field.DateTrunc("day", alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTsRankValidation tests TsRank alias validation
|
||||
func TestTsRankValidation(t *testing.T) {
|
||||
field := Field("search_vector")
|
||||
|
||||
validAliases := []string{"rank", "score", "relevance_123", "_rank"}
|
||||
for _, alias := range validAliases {
|
||||
t.Run("valid_"+alias, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("TsRank should not panic for valid alias %q", alias)
|
||||
}
|
||||
}()
|
||||
result := field.TsRank("query", alias)
|
||||
if !strings.Contains(string(result), alias) {
|
||||
t.Errorf("TsRank result should contain alias %q", alias)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
invalidAliases := []string{
|
||||
"'; DROP TABLE",
|
||||
"rank name",
|
||||
"rank-name",
|
||||
"123rank",
|
||||
}
|
||||
|
||||
for _, alias := range invalidAliases {
|
||||
t.Run("invalid_"+alias, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("TsRank should panic for invalid alias %q", alias)
|
||||
}
|
||||
}()
|
||||
field.TsRank("query", alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRowNumberValidation tests ROW_NUMBER alias validation
|
||||
func TestRowNumberValidation(t *testing.T) {
|
||||
field := Field("id")
|
||||
|
||||
t.Run("valid_alias", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("RowNumber should not panic for valid alias: %v", r)
|
||||
}
|
||||
}()
|
||||
result := field.RowNumber("row_num")
|
||||
if !strings.Contains(string(result), "row_num") {
|
||||
t.Errorf("RowNumber result should contain alias")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_alias", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("RowNumber should panic for invalid alias")
|
||||
}
|
||||
}()
|
||||
field.RowNumber("'; DROP TABLE")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSQLInjectionAttackVectors tests common SQL injection patterns
|
||||
func TestSQLInjectionAttackVectors(t *testing.T) {
|
||||
attacks := []string{
|
||||
"'; DROP TABLE users; --",
|
||||
"' OR '1'='1",
|
||||
"'; DELETE FROM users WHERE '1'='1",
|
||||
"1'; UPDATE users SET password='hacked'; --",
|
||||
"admin'--",
|
||||
"' UNION SELECT * FROM passwords--",
|
||||
}
|
||||
|
||||
t.Run("DateTrunc_alias_protection", func(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
for _, attack := range attacks {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("DateTrunc should prevent SQL injection: %q", attack)
|
||||
}
|
||||
}()
|
||||
field.DateTrunc("day", attack)
|
||||
}()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TsRank_alias_protection", func(t *testing.T) {
|
||||
field := Field("search_vector")
|
||||
for _, attack := range attacks {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("TsRank should prevent SQL injection: %q", attack)
|
||||
}
|
||||
}()
|
||||
field.TsRank("query", attack)
|
||||
}()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcatWs_separator_escaping", func(t *testing.T) {
|
||||
for _, attack := range attacks {
|
||||
result := ConcatWs(attack, Field("col1"))
|
||||
// Check that quotes are escaped (doubled)
|
||||
if strings.Contains(attack, "'") && !strings.Contains(string(result), "''") {
|
||||
t.Errorf("ConcatWs should escape quotes in attack: %q", attack)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user