forked from go/pgm
383 lines
9.6 KiB
Go
383 lines
9.6 KiB
Go
package pgm
|
|
|
|
import (
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// TestValidateSQLIdentifier tests SQL identifier validation
|
|
func TestValidateSQLIdentifier(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
wantErr bool
|
|
}{
|
|
{"valid simple", "column_name", false},
|
|
{"valid with numbers", "column123", false},
|
|
{"valid underscore", "_private", false},
|
|
{"valid mixed", "my_Column_123", false},
|
|
{"invalid with space", "column name", true},
|
|
{"invalid with dash", "column-name", true},
|
|
{"invalid with dot", "table.column", true},
|
|
{"invalid with quote", "column'name", true},
|
|
{"invalid with semicolon", "column;DROP", true},
|
|
{"invalid starts with number", "123column", true},
|
|
{"invalid SQL keyword injection", "col); DROP TABLE users; --", true},
|
|
{"empty string", "", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validateSQLIdentifier(tt.input)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("validateSQLIdentifier(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestEscapeSQLString tests SQL string escaping
|
|
func TestEscapeSQLString(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expected string
|
|
}{
|
|
{"no quotes", "hello", "hello"},
|
|
{"single quote", "hello'world", "hello''world"},
|
|
{"multiple quotes", "it's a 'test'", "it''s a ''test''"},
|
|
{"only quote", "'", "''"},
|
|
{"empty string", "", ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := escapeSQLString(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("escapeSQLString(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestConcatWsSQLInjectionPrevention tests that ConcatWs escapes quotes
|
|
func TestConcatWsSQLInjectionPrevention(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sep string
|
|
fields []Field
|
|
contains string
|
|
}{
|
|
{
|
|
name: "safe separator",
|
|
sep: ", ",
|
|
fields: []Field{"col1", "col2"},
|
|
contains: "concat_ws(', ',col1, col2)",
|
|
},
|
|
{
|
|
name: "escaped quotes",
|
|
sep: "', (SELECT password FROM users), '",
|
|
fields: []Field{"col1"},
|
|
contains: "concat_ws(''', (SELECT password FROM users), ''',col1)",
|
|
},
|
|
{
|
|
name: "single quote",
|
|
sep: "'",
|
|
fields: []Field{"col1"},
|
|
contains: "concat_ws('''',col1)",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := ConcatWs(tt.sep, tt.fields...)
|
|
if !strings.Contains(string(result), tt.contains) {
|
|
t.Errorf("ConcatWs(%q, %v) = %q, should contain %q", tt.sep, tt.fields, result, tt.contains)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestStringAggSQLInjectionPrevention tests that StringAgg escapes quotes
|
|
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
field Field
|
|
sep string
|
|
contains string
|
|
}{
|
|
{
|
|
name: "safe parameters",
|
|
field: Field("column_name"),
|
|
sep: ", ",
|
|
contains: "string_agg(column_name,', ')",
|
|
},
|
|
{
|
|
name: "escaped quotes in separator",
|
|
sep: "'; DROP TABLE users; --",
|
|
field: Field("col"),
|
|
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := StringAgg(tt.field, tt.sep)
|
|
if !strings.Contains(string(result), tt.contains) {
|
|
t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
|
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
|
result := StringAggCast(Field("column"), "'; DROP TABLE")
|
|
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
|
if string(result) != expected {
|
|
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
|
}
|
|
}
|
|
|
|
// TestDateTruncValidation tests DateTrunc level validation
|
|
func TestDateTruncValidation(t *testing.T) {
|
|
field := Field("created_at")
|
|
|
|
validLevels := []string{
|
|
"microseconds", "milliseconds", "second", "minute", "hour",
|
|
"day", "week", "month", "quarter", "year", "decade", "century", "millennium",
|
|
}
|
|
|
|
for _, level := range validLevels {
|
|
t.Run("valid_"+level, func(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
t.Errorf("DateTrunc(%q) should not panic for valid level", level)
|
|
}
|
|
}()
|
|
result := field.DateTrunc(level, "truncated")
|
|
if !strings.Contains(string(result), level) {
|
|
t.Errorf("DateTrunc result should contain level %q", level)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test case-insensitive
|
|
t.Run("case_insensitive", func(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
t.Errorf("DateTrunc should accept uppercase level")
|
|
}
|
|
}()
|
|
result := field.DateTrunc("MONTH", "truncated")
|
|
if !strings.Contains(strings.ToLower(string(result)), "month") {
|
|
t.Errorf("DateTrunc should normalize case")
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestDateTruncInvalidLevel tests that DateTrunc panics on invalid level
|
|
func TestDateTruncInvalidLevel(t *testing.T) {
|
|
field := Field("created_at")
|
|
|
|
invalidLevels := []string{
|
|
"invalid",
|
|
"'; DROP TABLE users; --",
|
|
"day; DELETE FROM",
|
|
"",
|
|
}
|
|
|
|
for _, level := range invalidLevels {
|
|
t.Run("invalid_"+level, func(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Errorf("DateTrunc(%q, 'alias') should panic for invalid level", level)
|
|
}
|
|
}()
|
|
field.DateTrunc(level, "truncated")
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDateTruncInvalidAlias tests that DateTrunc panics on invalid alias
|
|
func TestDateTruncInvalidAlias(t *testing.T) {
|
|
field := Field("created_at")
|
|
|
|
invalidAliases := []string{
|
|
"alias name",
|
|
"alias-name",
|
|
"'; DROP TABLE",
|
|
"123alias",
|
|
"alias.name",
|
|
}
|
|
|
|
for _, alias := range invalidAliases {
|
|
t.Run("invalid_alias_"+alias, func(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Errorf("DateTrunc('day', %q) should panic for invalid alias", alias)
|
|
}
|
|
}()
|
|
field.DateTrunc("day", alias)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestTsRankValidation tests TsRank alias validation
|
|
func TestTsRankValidation(t *testing.T) {
|
|
field := Field("search_vector")
|
|
|
|
validAliases := []string{"rank", "score", "relevance_123", "_rank"}
|
|
for _, alias := range validAliases {
|
|
t.Run("valid_"+alias, func(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
t.Errorf("TsRank should not panic for valid alias %q", alias)
|
|
}
|
|
}()
|
|
result := field.TsRank("query", alias)
|
|
if !strings.Contains(string(result), alias) {
|
|
t.Errorf("TsRank result should contain alias %q", alias)
|
|
}
|
|
})
|
|
}
|
|
|
|
invalidAliases := []string{
|
|
"'; DROP TABLE",
|
|
"rank name",
|
|
"rank-name",
|
|
"123rank",
|
|
}
|
|
|
|
for _, alias := range invalidAliases {
|
|
t.Run("invalid_"+alias, func(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Errorf("TsRank should panic for invalid alias %q", alias)
|
|
}
|
|
}()
|
|
field.TsRank("query", alias)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRowNumberValidation tests ROW_NUMBER alias validation
|
|
func TestRowNumberValidation(t *testing.T) {
|
|
field := Field("id")
|
|
|
|
t.Run("valid_alias", func(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
t.Errorf("RowNumber should not panic for valid alias: %v", r)
|
|
}
|
|
}()
|
|
result := field.RowNumber("row_num")
|
|
if !strings.Contains(string(result), "row_num") {
|
|
t.Errorf("RowNumber result should contain alias")
|
|
}
|
|
})
|
|
|
|
t.Run("invalid_alias", func(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Errorf("RowNumber should panic for invalid alias")
|
|
}
|
|
}()
|
|
field.RowNumber("'; DROP TABLE")
|
|
})
|
|
}
|
|
|
|
// TestSQLInjectionAttackVectors tests common SQL injection patterns
|
|
func TestSQLInjectionAttackVectors(t *testing.T) {
|
|
attacks := []string{
|
|
"'; DROP TABLE users; --",
|
|
"' OR '1'='1",
|
|
"'; DELETE FROM users WHERE '1'='1",
|
|
"1'; UPDATE users SET password='hacked'; --",
|
|
"admin'--",
|
|
"' UNION SELECT * FROM passwords--",
|
|
}
|
|
|
|
t.Run("DateTrunc_alias_protection", func(t *testing.T) {
|
|
field := Field("created_at")
|
|
for _, attack := range attacks {
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Errorf("DateTrunc should prevent SQL injection: %q", attack)
|
|
}
|
|
}()
|
|
field.DateTrunc("day", attack)
|
|
}()
|
|
}
|
|
})
|
|
|
|
t.Run("TsRank_alias_protection", func(t *testing.T) {
|
|
field := Field("search_vector")
|
|
for _, attack := range attacks {
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Errorf("TsRank should prevent SQL injection: %q", attack)
|
|
}
|
|
}()
|
|
field.TsRank("query", attack)
|
|
}()
|
|
}
|
|
})
|
|
|
|
t.Run("ConcatWs_separator_escaping", func(t *testing.T) {
|
|
for _, attack := range attacks {
|
|
result := ConcatWs(attack, Field("col1"))
|
|
// Check that quotes are escaped (doubled)
|
|
if strings.Contains(attack, "'") && !strings.Contains(string(result), "''") {
|
|
t.Errorf("ConcatWs should escape quotes in attack: %q", attack)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestFieldName tests Field.Name() method
|
|
func TestFieldName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
field Field
|
|
expected string
|
|
}{
|
|
{
|
|
name: "field with table prefix",
|
|
field: Field("users.email"),
|
|
expected: "email",
|
|
},
|
|
{
|
|
name: "field with schema and table prefix",
|
|
field: Field("public.users.email"),
|
|
expected: "email",
|
|
},
|
|
{
|
|
name: "field without dot",
|
|
field: Field("email"),
|
|
expected: "email",
|
|
},
|
|
{
|
|
name: "field with multiple dots",
|
|
field: Field("schema.table.column"),
|
|
expected: "column",
|
|
},
|
|
{
|
|
name: "aggregate function",
|
|
field: Field("COUNT(users.id)"),
|
|
expected: "id)",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := tt.field.Name()
|
|
if result != tt.expected {
|
|
t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|