package pgm

import (
	"strings"
	"testing"
)

// TestValidateSQLIdentifier checks validateSQLIdentifier against the table
// below: names built from letters, digits and underscores (including a
// leading underscore) are accepted, while whitespace, punctuation, quotes,
// a leading digit, injection payloads and the empty string are rejected.
func TestValidateSQLIdentifier(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		wantErr bool
	}{
		{"valid simple", "column_name", false},
		{"valid with numbers", "column123", false},
		{"valid underscore", "_private", false},
		{"valid mixed", "my_Column_123", false},
		{"invalid with space", "column name", true},
		{"invalid with dash", "column-name", true},
		{"invalid with dot", "table.column", true},
		{"invalid with quote", "column'name", true},
		{"invalid with semicolon", "column;DROP", true},
		{"invalid starts with number", "123column", true},
		{"invalid SQL keyword injection", "col); DROP TABLE users; --", true},
		{"empty string", "", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateSQLIdentifier(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("validateSQLIdentifier(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
			}
		})
	}
}

// TestEscapeSQLString checks that escapeSQLString doubles every single quote
// and leaves all other text untouched.
func TestEscapeSQLString(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{"no quotes", "hello", "hello"},
		{"single quote", "hello'world", "hello''world"},
		{"multiple quotes", "it's a 'test'", "it''s a ''test''"},
		{"only quote", "'", "''"},
		{"empty string", "", ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := escapeSQLString(tt.input)
			if result != tt.expected {
				t.Errorf("escapeSQLString(%q) = %q, want %q", tt.input, result, tt.expected)
			}
		})
	}
}

// TestConcatWsSQLInjectionPrevention checks that ConcatWs emits the separator
// as a quoted SQL literal with embedded quotes doubled, so a hostile separator
// cannot close the literal and inject a sub-query.
func TestConcatWsSQLInjectionPrevention(t *testing.T) {
	tests := []struct {
		name     string
		sep      string
		fields   []Field
		contains string
	}{
		{
			name:     "safe separator",
			sep:      ", ",
			fields:   []Field{"col1", "col2"},
			contains: "concat_ws(', ',col1, col2)",
		},
		{
			name:     "escaped quotes",
			sep:      "', (SELECT password FROM users), '",
			fields:   []Field{"col1"},
			contains: "concat_ws(''', (SELECT password FROM users), ''',col1)",
		},
		{
			name:     "single quote",
			sep:      "'",
			fields:   []Field{"col1"},
			contains: "concat_ws('''',col1)",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ConcatWs(tt.sep, tt.fields...)
			if !strings.Contains(string(result), tt.contains) {
				t.Errorf("ConcatWs(%q, %v) = %q, should contain %q", tt.sep, tt.fields, result, tt.contains)
			}
		})
	}
}

// TestStringAggSQLInjectionPrevention checks that StringAgg quotes its
// separator argument and doubles any quotes inside it.
func TestStringAggSQLInjectionPrevention(t *testing.T) {
	tests := []struct {
		name     string
		field    Field
		sep      string
		contains string
	}{
		{
			name:     "safe parameters",
			field:    Field("column_name"),
			sep:      ", ",
			contains: "string_agg(column_name,', ')",
		},
		{
			name:     "escaped quotes in separator",
			sep:      "'; DROP TABLE users; --",
			field:    Field("col"),
			contains: "string_agg(col,'''; DROP TABLE users; --')",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := StringAgg(tt.field, tt.sep)
			if !strings.Contains(string(result), tt.contains) {
				t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
			}
		})
	}
}

// TestStringAggCastSQLInjectionPrevention checks the exact SQL produced by
// StringAggCast for a separator containing a quote.
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
	result := StringAggCast(Field("column"), "'; DROP TABLE")
	expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
	if string(result) != expected {
		t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
	}
}

// TestDateTruncValidation checks that DateTrunc accepts every listed
// truncation level without panicking, embeds the level in its output, and
// accepts the level case-insensitively.
func TestDateTruncValidation(t *testing.T) {
	field := Field("created_at")
	validLevels := []string{
		"microseconds", "milliseconds", "second", "minute", "hour",
		"day", "week", "month", "quarter", "year",
		"decade", "century", "millennium",
	}

	for _, level := range validLevels {
		t.Run("valid_"+level, func(t *testing.T) {
			defer func() {
				if r := recover(); r != nil {
					t.Errorf("DateTrunc(%q) should not panic for valid level: %v", level, r)
				}
			}()
			result := field.DateTrunc(level, "truncated")
			if !strings.Contains(string(result), level) {
				t.Errorf("DateTrunc result should contain level %q", level)
			}
		})
	}

	// Upper-case input must be accepted; the output is compared
	// case-insensitively so either normalisation strategy passes.
	t.Run("case_insensitive", func(t *testing.T) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("DateTrunc should accept uppercase level: %v", r)
			}
		}()
		result := field.DateTrunc("MONTH", "truncated")
		if !strings.Contains(strings.ToLower(string(result)), "month") {
			t.Errorf("DateTrunc should normalize case")
		}
	})
}

// TestDateTruncInvalidLevel checks that DateTrunc panics for unknown levels,
// including injection payloads and the empty string.
func TestDateTruncInvalidLevel(t *testing.T) {
	const alias = "truncated"
	field := Field("created_at")
	invalidLevels := []string{
		"invalid",
		"'; DROP TABLE users; --",
		"day; DELETE FROM",
		"",
	}

	for _, level := range invalidLevels {
		t.Run("invalid_"+level, func(t *testing.T) {
			defer func() {
				if r := recover(); r == nil {
					t.Errorf("DateTrunc(%q, %q) should panic for invalid level", level, alias)
				}
			}()
			field.DateTrunc(level, alias)
		})
	}
}

// TestDateTruncInvalidAlias checks that DateTrunc panics when the alias is not
// a plain SQL identifier.
func TestDateTruncInvalidAlias(t *testing.T) {
	field := Field("created_at")
	invalidAliases := []string{
		"alias name",
		"alias-name",
		"'; DROP TABLE",
		"123alias",
		"alias.name",
	}

	for _, alias := range invalidAliases {
		t.Run("invalid_alias_"+alias, func(t *testing.T) {
			defer func() {
				if r := recover(); r == nil {
					t.Errorf("DateTrunc('day', %q) should panic for invalid alias", alias)
				}
			}()
			field.DateTrunc("day", alias)
		})
	}
}

// TestTsRankValidation checks that TsRank accepts identifier-shaped aliases
// (and includes them in its output) and panics on anything else.
func TestTsRankValidation(t *testing.T) {
	field := Field("search_vector")

	validAliases := []string{"rank", "score", "relevance_123", "_rank"}
	for _, alias := range validAliases {
		t.Run("valid_"+alias, func(t *testing.T) {
			defer func() {
				if r := recover(); r != nil {
					t.Errorf("TsRank should not panic for valid alias %q: %v", alias, r)
				}
			}()
			result := field.TsRank("query", alias)
			if !strings.Contains(string(result), alias) {
				t.Errorf("TsRank result should contain alias %q", alias)
			}
		})
	}

	invalidAliases := []string{
		"'; DROP TABLE",
		"rank name",
		"rank-name",
		"123rank",
	}
	for _, alias := range invalidAliases {
		t.Run("invalid_"+alias, func(t *testing.T) {
			defer func() {
				if r := recover(); r == nil {
					t.Errorf("TsRank should panic for invalid alias %q", alias)
				}
			}()
			field.TsRank("query", alias)
		})
	}
}

// TestRowNumberValidation checks that RowNumber accepts an identifier alias
// and panics on an injection payload.
func TestRowNumberValidation(t *testing.T) {
	field := Field("id")

	t.Run("valid_alias", func(t *testing.T) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("RowNumber should not panic for valid alias: %v", r)
			}
		}()
		result := field.RowNumber("row_num")
		if !strings.Contains(string(result), "row_num") {
			t.Errorf("RowNumber result should contain alias")
		}
	})

	t.Run("invalid_alias", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Errorf("RowNumber should panic for invalid alias")
			}
		}()
		field.RowNumber("'; DROP TABLE")
	})
}

// TestSQLInjectionAttackVectors runs a set of common injection payloads
// through every builder that accepts caller-supplied text. Alias-taking
// builders must reject them (panic), and separator-taking builders must embed
// them as a single quoted literal with all quotes doubled.
func TestSQLInjectionAttackVectors(t *testing.T) {
	attacks := []string{
		"'; DROP TABLE users; --",
		"' OR '1'='1",
		"'; DELETE FROM users WHERE '1'='1",
		"1'; UPDATE users SET password='hacked'; --",
		"admin'--",
		"' UNION SELECT * FROM passwords--",
	}

	t.Run("DateTrunc_alias_protection", func(t *testing.T) {
		field := Field("created_at")
		for _, attack := range attacks {
			func() {
				defer func() {
					if r := recover(); r == nil {
						t.Errorf("DateTrunc should prevent SQL injection: %q", attack)
					}
				}()
				field.DateTrunc("day", attack)
			}()
		}
	})

	t.Run("TsRank_alias_protection", func(t *testing.T) {
		field := Field("search_vector")
		for _, attack := range attacks {
			func() {
				defer func() {
					if r := recover(); r == nil {
						t.Errorf("TsRank should prevent SQL injection: %q", attack)
					}
				}()
				field.TsRank("query", attack)
			}()
		}
	})

	t.Run("ConcatWs_separator_escaping", func(t *testing.T) {
		for _, attack := range attacks {
			result := ConcatWs(attack, Field("col1"))
			// Finding some "''" somewhere in the output does not prove the
			// payload stayed inside the literal; require the exact escaped
			// literal in the position TestConcatWsSQLInjectionPrevention pins.
			want := "concat_ws('" + strings.ReplaceAll(attack, "'", "''") + "',col1)"
			if !strings.Contains(string(result), want) {
				t.Errorf("ConcatWs should escape quotes in attack %q: got %q, want it to contain %q", attack, result, want)
			}
		}
	})

	t.Run("StringAgg_separator_escaping", func(t *testing.T) {
		for _, attack := range attacks {
			result := StringAgg(Field("col"), attack)
			// Same exact-literal requirement, using the output shape pinned by
			// TestStringAggSQLInjectionPrevention.
			want := "string_agg(col,'" + strings.ReplaceAll(attack, "'", "''") + "')"
			if !strings.Contains(string(result), want) {
				t.Errorf("StringAgg should escape quotes in attack %q: got %q, want it to contain %q", attack, result, want)
			}
		}
	})
}

// TestFieldName checks that Field.Name returns the text after the last dot,
// or the whole field when it has no dot.
func TestFieldName(t *testing.T) {
	tests := []struct {
		name     string
		field    Field
		expected string
	}{
		{
			name:     "field with table prefix",
			field:    Field("users.email"),
			expected: "email",
		},
		{
			name:     "field with schema and table prefix",
			field:    Field("public.users.email"),
			expected: "email",
		},
		{
			name:     "field without dot",
			field:    Field("email"),
			expected: "email",
		},
		{
			name:     "field with multiple dots",
			field:    Field("schema.table.column"),
			expected: "column",
		},
		{
			// Name does not understand function-call syntax; this case pins
			// the current behaviour of returning everything after the last
			// dot, trailing parenthesis included.
			name:     "aggregate function",
			field:    Field("COUNT(users.id)"),
			expected: "id)",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := tt.field.Name()
			if result != tt.expected {
				t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
			}
		})
	}
}