// Patial Tech. // Author, Ankit Patial package pgm import ( "context" "errors" "strings" "testing" "time" "github.com/jackc/pgx/v5" ) // TestGetPoolNotInitialized verifies that GetPool panics when pool is not initialized func TestGetPoolNotInitialized(t *testing.T) { // Save current pool state currentPool := poolPGX.Load() // Clear the pool to simulate uninitialized state poolPGX.Store(nil) // Restore pool after test defer func() { poolPGX.Store(currentPool) }() // Verify panic occurs defer func() { if r := recover(); r == nil { t.Error("GetPool should panic when pool not initialized") } else { // Check panic message msg, ok := r.(string) if !ok || !strings.Contains(msg, "not initialized") { t.Errorf("Expected panic message about initialization, got: %v", r) } } }() GetPool() // Should panic } // TestConfigValidation tests that InitPool validates configuration func TestConfigValidation(t *testing.T) { tests := []struct { name string config Config wantPanic bool panicMsg string }{ { name: "empty connection string", config: Config{ ConnString: "", }, wantPanic: true, panicMsg: "connection string is empty", }, { name: "MinConns greater than MaxConns", config: Config{ ConnString: "postgres://localhost/test", MaxConns: 5, MinConns: 10, }, wantPanic: true, panicMsg: "MinConns", }, { name: "negative MaxConns", config: Config{ ConnString: "postgres://localhost/test", MaxConns: -1, }, wantPanic: true, panicMsg: "negative", }, { name: "negative MinConns", config: Config{ ConnString: "postgres://localhost/test", MinConns: -1, }, wantPanic: true, panicMsg: "negative", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { defer func() { r := recover() if tt.wantPanic && r == nil { t.Errorf("InitPool should panic for %s", tt.name) } if !tt.wantPanic && r != nil { t.Errorf("InitPool should not panic for %s, got: %v", tt.name, r) } if tt.wantPanic && r != nil { msg := "" switch v := r.(type) { case string: msg = v case error: msg = v.Error() } if !strings.Contains(msg, tt.panicMsg) { t.Errorf("Expected panic message to contain %q, got: %q", tt.panicMsg, msg) } } }() InitPool(tt.config) }) } } // TestClosePoolIdempotent verifies ClosePool can be called multiple times safely func TestClosePoolIdempotent(t *testing.T) { // Save current pool currentPool := poolPGX.Load() defer func() { poolPGX.Store(currentPool) }() // Set to nil poolPGX.Store(nil) // Should not panic when called on nil pool ClosePool() ClosePool() ClosePool() // Should work fine - no panic expected } // TestIsNotFound tests the error checking utility func TestIsNotFound(t *testing.T) { // Test with pgx.ErrNoRows if !IsNotFound(pgx.ErrNoRows) { t.Error("IsNotFound should return true for pgx.ErrNoRows") } // Test with other errors otherErr := errors.New("some other error") if IsNotFound(otherErr) { t.Error("IsNotFound should return false for non-ErrNoRows errors") } // Test with nil if IsNotFound(nil) { t.Error("IsNotFound should return false for nil") } } // TestPgTime tests PostgreSQL timestamp conversion func TestPgTime(t *testing.T) { now := time.Now() pgTime := PgTime(now) if !pgTime.Valid { t.Error("PgTime should return valid timestamp") } if !pgTime.Time.Equal(now) { t.Errorf("PgTime time mismatch: got %v, want %v", pgTime.Time, now) } } // TestPgTimeNow tests current time conversion func TestPgTimeNow(t *testing.T) { before := time.Now() pgTime := PgTimeNow() after := time.Now() if !pgTime.Valid { t.Error("PgTimeNow should return valid timestamp") } // Check that time is between before and after if pgTime.Time.Before(before) || pgTime.Time.After(after) { t.Error("PgTimeNow should return current time") } } // TestTsQueryFunctions tests full-text search query builders func TestTsQueryFunctions(t *testing.T) { tests := []struct { name string fn func(string) string input string expected string }{ { name: "TsAndQuery", fn: TsAndQuery, input: "hello world", expected: "hello & world", }, { name: "TsPrefixAndQuery", fn: TsPrefixAndQuery, input: "hello world", expected: "hello:* & world:*", }, { name: "TsOrQuery", fn: TsOrQuery, input: "hello world", expected: "hello | world", }, { name: "TsPrefixOrQuery", fn: TsPrefixOrQuery, input: "hello world", expected: "hello:* | world:*", }, { name: "TsAndQuery with multiple words", fn: TsAndQuery, input: "one two three", expected: "one & two & three", }, { name: "TsOrQuery with multiple words", fn: TsOrQuery, input: "one two three", expected: "one | two | three", }, { name: "TsAndQuery with extra spaces", fn: TsAndQuery, input: "hello world", expected: "hello & world", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.fn(tt.input) if result != tt.expected { t.Errorf("%s(%q) = %q, want %q", tt.name, tt.input, result, tt.expected) } }) } } // TestFieldsWithSuffix tests the internal suffix function func TestFieldsWithSuffix(t *testing.T) { result := fieldsWithSufix("hello world", ":*") expected := []string{"hello:*", "world:*"} if len(result) != len(expected) { t.Fatalf("fieldsWithSufix length mismatch: got %d, want %d", len(result), len(expected)) } for i, v := range result { if v != expected[i] { t.Errorf("fieldsWithSufix[%d] = %q, want %q", i, v, expected[i]) } } } // TestBeginTxErrorWrapping tests that BeginTx preserves error context func TestBeginTxErrorWrapping(t *testing.T) { // This test verifies basic behavior without requiring a database connection // Actual error wrapping would require a failing database connection // Skip if no pool initialized (unit test environment) if poolPGX.Load() == nil { t.Skip("Skipping BeginTx test - requires initialized pool") } // Just verify the function accepts a context // In a real scenario with a cancelled context, it would return an error ctx := context.Background() // We can't fully test without a real DB connection, so just document behavior _, err := BeginTx(ctx) if err != nil { // Error is expected in test environment without valid DB t.Logf("BeginTx returned error (expected in test env): %v", err) } } // TestContextCancellation tests that cancelled context is handled func TestContextCancellation(t *testing.T) { // Create an already-cancelled context ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately // Note: This test would require an actual database connection to verify // that the context cancellation is properly propagated. We're testing // that the context is accepted by the function signature. // The actual behavior would be tested in integration tests // For now, we just verify the function doesn't panic with cancelled context // Skip if no pool initialized (unit test environment) if poolPGX.Load() == nil { t.Skip("Skipping context cancellation test - requires initialized pool") } // This would return an error about cancelled context in real scenario _, err := BeginTx(ctx) if err == nil { t.Log("BeginTx with cancelled context - error expected in real scenario") } } // TestErrorTypes verifies exported error variables func TestErrorTypes(t *testing.T) { if ErrConnStringMissing == nil { t.Error("ErrConnStringMissing should be defined") } if ErrInitTX == nil { t.Error("ErrInitTX should be defined") } if ErrCommitTX == nil { t.Error("ErrCommitTX should be defined") } if ErrNoRows == nil { t.Error("ErrNoRows should be defined") } // Check error messages are descriptive if !strings.Contains(ErrConnStringMissing.Error(), "connection string") { t.Error("ErrConnStringMissing should mention connection string") } } // TestStringPoolGetPut tests the string builder pool func TestStringPoolGetPut(t *testing.T) { // Get a string builder from pool sb1 := getSB() if sb1 == nil { t.Fatal("getSB should return non-nil string builder") } // Use it sb1.WriteString("test") if sb1.String() != "test" { t.Error("String builder should work normally") } // Put it back putSB(sb1) // Get another one - should be reset sb2 := getSB() if sb2 == nil { t.Fatal("getSB should return non-nil string builder") } // Should be empty after reset if sb2.Len() != 0 { t.Error("String builder from pool should be reset") } putSB(sb2) } // TestConditionActionTypes tests condition action enum func TestConditionActionTypes(t *testing.T) { // Verify enum values are distinct actions := []CondAction{ CondActionNothing, CondActionNeedToClose, CondActionSubQuery, } seen := make(map[CondAction]bool) for _, action := range actions { if seen[action] { t.Errorf("Duplicate CondAction value: %v", action) } seen[action] = true } } // TestSelectQueryBuilderBasics tests basic query building func TestSelectQueryBuilderBasics(t *testing.T) { // Create a test table testTable := &Table{Name: "users", FieldCount: 10} idField := Field("users.id") emailField := Field("users.email") // Test basic select qry := testTable.Select(idField, emailField) sql := qry.String() if !strings.Contains(sql, "SELECT") { t.Error("Query should contain SELECT") } if !strings.Contains(sql, "users.id") { t.Error("Query should contain users.id field") } if !strings.Contains(sql, "users.email") { t.Error("Query should contain users.email field") } if !strings.Contains(sql, "FROM users") { t.Error("Query should contain FROM users") } } // TestWhereConditionAccumulation tests that WHERE conditions accumulate func TestWhereConditionAccumulation(t *testing.T) { testTable := &Table{Name: "users", FieldCount: 10} idField := Field("users.id") statusField := Field("users.status") // Create a query and add multiple WHERE conditions qry := testTable.Select(idField). Where(idField.Eq(1)). Where(statusField.Eq("active")) sql := qry.String() // Both conditions should be present if !strings.Contains(sql, "WHERE") { t.Error("Query should contain WHERE clause") } // Should have multiple conditions (this demonstrates accumulation) if strings.Count(sql, "$") < 2 { t.Error("Query should have multiple parameterized values") } } // TestInsertQueryBuilder tests insert query building func TestInsertQueryBuilder(t *testing.T) { testTable := &Table{Name: "users", FieldCount: 10} qry := testTable.Insert() sql := qry.String() if !strings.Contains(sql, "INSERT INTO users") { t.Error("Query should contain INSERT INTO users") } } // TestUpdateQueryBuilder tests update query building func TestUpdateQueryBuilder(t *testing.T) { testTable := &Table{Name: "users", FieldCount: 10} idField := Field("users.id") qry := testTable.Update(). Where(idField.Eq(1)) sql := qry.String() if !strings.Contains(sql, "UPDATE users") { t.Error("Query should contain UPDATE users") } if !strings.Contains(sql, "WHERE") { t.Error("Query should contain WHERE clause") } } // TestDeleteQueryBuilder tests delete query building func TestDeleteQueryBuilder(t *testing.T) { testTable := &Table{Name: "users", FieldCount: 10} idField := Field("users.id") qry := testTable.Delete().Where(idField.Eq(1)) sql := qry.String() if !strings.Contains(sql, "DELETE FROM users") { t.Error("Query should contain DELETE FROM users") } if !strings.Contains(sql, "WHERE") { t.Error("Query should contain WHERE clause") } } // TestFieldConditioners tests various field condition methods func TestFieldConditioners(t *testing.T) { field := Field("users.age") tests := []struct { name string condition Conditioner checkFunc func(string) bool }{ { name: "Eq", condition: field.Eq(25), checkFunc: func(s string) bool { return strings.Contains(s, " = $") }, }, { name: "NotEq", condition: field.NotEq(25), checkFunc: func(s string) bool { return strings.Contains(s, " != $") }, }, { name: "Gt", condition: field.Gt(25), checkFunc: func(s string) bool { return strings.Contains(s, " > $") }, }, { name: "Lt", condition: field.Lt(25), checkFunc: func(s string) bool { return strings.Contains(s, " < $") }, }, { name: "Gte", condition: field.Gte(25), checkFunc: func(s string) bool { return strings.Contains(s, " >= $") }, }, { name: "Lte", condition: field.Lte(25), checkFunc: func(s string) bool { return strings.Contains(s, " <= $") }, }, { name: "IsNull", condition: field.IsNull(), checkFunc: func(s string) bool { return strings.Contains(s, " IS NULL") }, }, { name: "IsNotNull", condition: field.IsNotNull(), checkFunc: func(s string) bool { return strings.Contains(s, " IS NOT NULL") }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Build a simple query to see the condition testTable := &Table{Name: "users", FieldCount: 10} qry := testTable.Select(field).Where(tt.condition) sql := qry.String() if !tt.checkFunc(sql) { t.Errorf("%s condition not properly rendered in SQL: %s", tt.name, sql) } }) } } // TestFieldFunctions tests SQL function wrappers func TestFieldFunctions(t *testing.T) { field := Field("users.count") tests := []struct { name string result Field contains string }{ { name: "Count", result: field.Count(), contains: "COUNT(", }, { name: "Sum", result: field.Sum(), contains: "SUM(", }, { name: "Avg", result: field.Avg(), contains: "AVG(", }, { name: "Max", result: field.Max(), contains: "MAX(", }, { name: "Min", result: field.Min(), contains: "Min(", }, { name: "Lower", result: field.Lower(), contains: "LOWER(", }, { name: "Upper", result: field.Upper(), contains: "UPPER(", }, { name: "Trim", result: field.Trim(), contains: "TRIM(", }, { name: "Asc", result: field.Asc(), contains: " ASC", }, { name: "Desc", result: field.Desc(), contains: " DESC", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { resultStr := string(tt.result) if !strings.Contains(resultStr, tt.contains) { t.Errorf("%s should contain %q, got: %s", tt.name, tt.contains, resultStr) } }) } } // TestInsertQueryValidation tests that Insert queries validate fields are set func TestInsertQueryValidation(t *testing.T) { testTable := &Table{Name: "users", FieldCount: 10} idField := Field("users.id") t.Run("Exec without fields", func(t *testing.T) { qry := testTable.Insert() err := qry.Exec(context.Background()) if err == nil { t.Error("Insert.Exec() should return error when no fields set") } if !strings.Contains(err.Error(), "no fields to insert") { t.Errorf("Expected error about no fields, got: %v", err) } }) t.Run("ExecTx without fields", func(t *testing.T) { qry := testTable.Insert() // We can't test ExecTx without a real transaction, but we can test the error is defined err := qry.ExecTx(context.Background(), nil) if err == nil { t.Error("Insert.ExecTx() should return error when no fields set") } if !strings.Contains(err.Error(), "no fields to insert") { t.Errorf("Expected error about no fields, got: %v", err) } }) t.Run("First without fields", func(t *testing.T) { qry := testTable.Insert().Returning(idField) var id int err := qry.First(context.Background(), &id) if err == nil { t.Error("Insert.First() should return error when no fields set") } if !strings.Contains(err.Error(), "no fields to insert") { t.Errorf("Expected error about no fields, got: %v", err) } }) t.Run("FirstTx without fields", func(t *testing.T) { qry := testTable.Insert().Returning(idField) var id int err := qry.FirstTx(context.Background(), nil, &id) if err == nil { t.Error("Insert.FirstTx() should return error when no fields set") } if !strings.Contains(err.Error(), "no fields to insert") { t.Errorf("Expected error about no fields, got: %v", err) } }) }