perf enhancement
This commit is contained in:
@@ -102,29 +102,29 @@ func TestConcatWsSQLInjectionPrevention(t *testing.T) {
|
||||
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exp string
|
||||
field Field
|
||||
sep string
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "safe parameters",
|
||||
exp: "column_name",
|
||||
field: Field("column_name"),
|
||||
sep: ", ",
|
||||
contains: "string_agg(column_name,', ')",
|
||||
},
|
||||
{
|
||||
name: "escaped quotes in separator",
|
||||
sep: "'; DROP TABLE users; --",
|
||||
exp: "col",
|
||||
field: Field("col"),
|
||||
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := StringAgg(tt.exp, tt.sep)
|
||||
result := StringAgg(tt.field, tt.sep)
|
||||
if !strings.Contains(string(result), tt.contains) {
|
||||
t.Errorf("StringAgg(%q, %q) = %q, should contain %q", tt.exp, tt.sep, result, tt.contains)
|
||||
t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||
|
||||
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
||||
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
||||
result := StringAggCast("column", "'; DROP TABLE")
|
||||
result := StringAggCast(Field("column"), "'; DROP TABLE")
|
||||
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
||||
if string(result) != expected {
|
||||
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
||||
@@ -336,3 +336,47 @@ func TestSQLInjectionAttackVectors(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFieldName tests Field.Name() method
|
||||
func TestFieldName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field Field
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "field with table prefix",
|
||||
field: Field("users.email"),
|
||||
expected: "email",
|
||||
},
|
||||
{
|
||||
name: "field with schema and table prefix",
|
||||
field: Field("public.users.email"),
|
||||
expected: "email",
|
||||
},
|
||||
{
|
||||
name: "field without dot",
|
||||
field: Field("email"),
|
||||
expected: "email",
|
||||
},
|
||||
{
|
||||
name: "field with multiple dots",
|
||||
field: Field("schema.table.column"),
|
||||
expected: "column",
|
||||
},
|
||||
{
|
||||
name: "aggregate function",
|
||||
field: Field("COUNT(users.id)"),
|
||||
expected: "id)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.field.Name()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user