perf enhancement

This commit is contained in:
2025-11-16 16:21:35 +05:30
parent 29cddb6389
commit 1d9d9d9308
9 changed files with 917 additions and 35 deletions

View File

@@ -102,29 +102,29 @@ func TestConcatWsSQLInjectionPrevention(t *testing.T) {
func TestStringAggSQLInjectionPrevention(t *testing.T) {
tests := []struct {
name string
exp string
field Field
sep string
contains string
}{
{
name: "safe parameters",
exp: "column_name",
field: Field("column_name"),
sep: ", ",
contains: "string_agg(column_name,', ')",
},
{
name: "escaped quotes in separator",
sep: "'; DROP TABLE users; --",
exp: "col",
field: Field("col"),
contains: "string_agg(col,'''; DROP TABLE users; --')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := StringAgg(tt.exp, tt.sep)
result := StringAgg(tt.field, tt.sep)
if !strings.Contains(string(result), tt.contains) {
t.Errorf("StringAgg(%q, %q) = %q, should contain %q", tt.exp, tt.sep, result, tt.contains)
t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
}
})
}
@@ -132,7 +132,7 @@ func TestStringAggSQLInjectionPrevention(t *testing.T) {
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
result := StringAggCast("column", "'; DROP TABLE")
result := StringAggCast(Field("column"), "'; DROP TABLE")
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
if string(result) != expected {
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
@@ -336,3 +336,47 @@ func TestSQLInjectionAttackVectors(t *testing.T) {
}
})
}
// TestFieldName tests Field.Name() method
func TestFieldName(t *testing.T) {
tests := []struct {
name string
field Field
expected string
}{
{
name: "field with table prefix",
field: Field("users.email"),
expected: "email",
},
{
name: "field with schema and table prefix",
field: Field("public.users.email"),
expected: "email",
},
{
name: "field without dot",
field: Field("email"),
expected: "email",
},
{
name: "field with multiple dots",
field: Field("schema.table.column"),
expected: "column",
},
{
name: "aggregate function",
field: Field("COUNT(users.id)"),
expected: "id)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.field.Name()
if result != tt.expected {
t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
}
})
}
}