Files
pgm/pgm_test.go

671 lines
16 KiB
Go

// Patial Tech.
// Author, Ankit Patial
package pgm
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5"
)
// TestGetPoolNotInitialized checks that GetPool panics, with a message that
// mentions "not initialized", when no pool has been stored. The previously
// stored pool is put back once the test finishes.
func TestGetPoolNotInitialized(t *testing.T) {
	saved := poolPGX.Load()
	t.Cleanup(func() { poolPGX.Store(saved) })

	// Simulate a process that never called InitPool.
	poolPGX.Store(nil)

	defer func() {
		r := recover()
		if r == nil {
			t.Error("GetPool should panic when pool not initialized")
			return
		}
		// Only a string panic value carrying the expected wording is accepted.
		if msg, ok := r.(string); !ok || !strings.Contains(msg, "not initialized") {
			t.Errorf("Expected panic message about initialization, got: %v", r)
		}
	}()

	GetPool() // expected to panic
}
// TestConfigValidation tests that InitPool rejects invalid configurations by
// panicking with a message that contains the expected fragment.
//
// Each subtest snapshots the package-level pool and restores it afterwards.
// Without this, a regression that makes InitPool accept one of these configs
// would install a pool for postgres://localhost/test and leak it into later
// tests (e.g. the BeginTx tests, which only skip when the pool is nil).
func TestConfigValidation(t *testing.T) {
	tests := []struct {
		name      string
		config    Config
		wantPanic bool
		panicMsg  string
	}{
		{
			name: "empty connection string",
			config: Config{
				ConnString: "",
			},
			wantPanic: true,
			panicMsg:  "connection string is empty",
		},
		{
			name: "MinConns greater than MaxConns",
			config: Config{
				ConnString: "postgres://localhost/test",
				MaxConns:   5,
				MinConns:   10,
			},
			wantPanic: true,
			panicMsg:  "MinConns",
		},
		{
			name: "negative MaxConns",
			config: Config{
				ConnString: "postgres://localhost/test",
				MaxConns:   -1,
			},
			wantPanic: true,
			panicMsg:  "negative",
		},
		{
			name: "negative MinConns",
			config: Config{
				ConnString: "postgres://localhost/test",
				MinConns:   -1,
			},
			wantPanic: true,
			panicMsg:  "negative",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			saved := poolPGX.Load()
			t.Cleanup(func() {
				if poolPGX.Load() != saved {
					// InitPool replaced the shared pool. Close the replacement before
					// restoring. NOTE(review): assumes ClosePool closes whatever pool
					// is currently stored — confirm.
					ClosePool()
				}
				poolPGX.Store(saved)
			})

			defer func() {
				r := recover()
				switch {
				case tt.wantPanic && r == nil:
					t.Errorf("InitPool should panic for %s", tt.name)
				case !tt.wantPanic && r != nil:
					t.Errorf("InitPool should not panic for %s, got: %v", tt.name, r)
				case tt.wantPanic:
					// r != nil here: extract a message from the common panic payloads.
					var msg string
					switch v := r.(type) {
					case string:
						msg = v
					case error:
						msg = v.Error()
					}
					if !strings.Contains(msg, tt.panicMsg) {
						// Report the raw value so non-string/non-error payloads are visible.
						t.Errorf("Expected panic message to contain %q, got: %v", tt.panicMsg, r)
					}
				}
			}()

			InitPool(tt.config)
		})
	}
}
// TestClosePoolIdempotent checks that ClosePool tolerates repeated calls when
// no pool is stored. The test passes as long as none of the calls panic.
func TestClosePoolIdempotent(t *testing.T) {
	saved := poolPGX.Load()
	t.Cleanup(func() { poolPGX.Store(saved) })

	poolPGX.Store(nil)

	for i := 0; i < 3; i++ {
		ClosePool()
	}
}
// TestIsNotFound tests the error checking utility
func TestIsNotFound(t *testing.T) {
// Test with pgx.ErrNoRows
if !IsNotFound(pgx.ErrNoRows) {
t.Error("IsNotFound should return true for pgx.ErrNoRows")
}
// Test with other errors
otherErr := errors.New("some other error")
if IsNotFound(otherErr) {
t.Error("IsNotFound should return false for non-ErrNoRows errors")
}
// Test with nil
if IsNotFound(nil) {
t.Error("IsNotFound should return false for nil")
}
}
// TestPgTime checks that PgTime wraps a time.Time as a valid value holding
// the same instant.
func TestPgTime(t *testing.T) {
	want := time.Now()
	got := PgTime(want)

	if !got.Valid {
		t.Error("PgTime should return valid timestamp")
	}
	if !got.Time.Equal(want) {
		t.Errorf("PgTime time mismatch: got %v, want %v", got.Time, want)
	}
}
// TestPgTimeNow checks that PgTimeNow returns a valid value whose instant lies
// between two surrounding time.Now() readings.
func TestPgTimeNow(t *testing.T) {
	lo := time.Now()
	ts := PgTimeNow()
	hi := time.Now()

	if !ts.Valid {
		t.Error("PgTimeNow should return valid timestamp")
	}
	inWindow := !ts.Time.Before(lo) && !ts.Time.After(hi)
	if !inWindow {
		t.Error("PgTimeNow should return current time")
	}
}
// TestTsQueryFunctions tests the full-text search query builders, which join
// whitespace-separated words with "&" (AND) or "|" (OR), optionally suffixing
// each word with the prefix-match marker ":*".
func TestTsQueryFunctions(t *testing.T) {
	tests := []struct {
		name     string
		fn       func(string) string
		input    string
		expected string
	}{
		{
			name:     "TsAndQuery",
			fn:       TsAndQuery,
			input:    "hello world",
			expected: "hello & world",
		},
		{
			name:     "TsPrefixAndQuery",
			fn:       TsPrefixAndQuery,
			input:    "hello world",
			expected: "hello:* & world:*",
		},
		{
			name:     "TsOrQuery",
			fn:       TsOrQuery,
			input:    "hello world",
			expected: "hello | world",
		},
		{
			name:     "TsPrefixOrQuery",
			fn:       TsPrefixOrQuery,
			input:    "hello world",
			expected: "hello:* | world:*",
		},
		{
			name:     "TsAndQuery with multiple words",
			fn:       TsAndQuery,
			input:    "one two three",
			expected: "one & two & three",
		},
		{
			name:     "TsOrQuery with multiple words",
			fn:       TsOrQuery,
			input:    "one two three",
			expected: "one | two | three",
		},
		{
			// Runs of spaces must collapse: no empty terms may appear between operators.
			name:     "TsAndQuery with extra spaces",
			fn:       TsAndQuery,
			input:    "hello   world",
			expected: "hello & world",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := tt.fn(tt.input)
			if result != tt.expected {
				t.Errorf("%s(%q) = %q, want %q", tt.name, tt.input, result, tt.expected)
			}
		})
	}
}
// TestFieldsWithSuffix checks that fieldsWithSufix splits its input into words
// and appends the given suffix to each one.
func TestFieldsWithSuffix(t *testing.T) {
	want := []string{"hello:*", "world:*"}
	got := fieldsWithSufix("hello world", ":*")

	if len(got) != len(want) {
		t.Fatalf("fieldsWithSufix length mismatch: got %d, want %d", len(got), len(want))
	}
	for i := range got {
		if got[i] != want[i] {
			t.Errorf("fieldsWithSufix[%d] = %q, want %q", i, got[i], want[i])
		}
	}
}
// TestBeginTxErrorWrapping is a smoke test for BeginTx against the shared
// pool. It is skipped when no pool has been initialized (plain unit-test
// runs). A failure to begin is only logged, because the configured database
// may be unreachable. A successfully started transaction is rolled back so
// its pooled connection is released instead of leaking for the rest of the
// test binary.
func TestBeginTxErrorWrapping(t *testing.T) {
	if poolPGX.Load() == nil {
		t.Skip("Skipping BeginTx test - requires initialized pool")
	}

	ctx := context.Background()
	tx, err := BeginTx(ctx)
	if err != nil {
		// Error is expected in test environment without valid DB.
		t.Logf("BeginTx returned error (expected in test env): %v", err)
		return
	}

	// NOTE(review): assumes BeginTx returns a pgx.Tx (or a type with
	// Rollback(context.Context) error) — confirm against its declaration.
	if rbErr := tx.Rollback(ctx); rbErr != nil {
		t.Logf("Rollback after BeginTx returned error: %v", rbErr)
	}
}
// TestContextCancellation checks that BeginTx reports an error when it is
// given an already-cancelled context, instead of starting a transaction.
// It requires an initialized pool and is skipped otherwise.
func TestContextCancellation(t *testing.T) {
	if poolPGX.Load() == nil {
		t.Skip("Skipping context cancellation test - requires initialized pool")
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancel before use so BeginTx sees a dead context

	tx, err := BeginTx(ctx)
	if err == nil {
		// Release the unexpected transaction. ctx is cancelled, so use a fresh
		// context for the rollback. NOTE(review): assumes BeginTx returns a
		// pgx.Tx-like value with Rollback(context.Context) error — confirm.
		_ = tx.Rollback(context.Background())
		t.Error("BeginTx should return an error for an already-cancelled context")
	}
}
// TestErrorTypes checks that the package's exported sentinel errors are
// non-nil. It also checks that ErrConnStringMissing has a descriptive message.
func TestErrorTypes(t *testing.T) {
	sentinels := []struct {
		err  error
		fail string
	}{
		{ErrConnStringMissing, "ErrConnStringMissing should be defined"},
		{ErrInitTX, "ErrInitTX should be defined"},
		{ErrCommitTX, "ErrCommitTX should be defined"},
		{ErrNoRows, "ErrNoRows should be defined"},
	}
	for _, s := range sentinels {
		if s.err == nil {
			t.Error(s.fail)
		}
	}

	if text := ErrConnStringMissing.Error(); !strings.Contains(text, "connection string") {
		t.Error("ErrConnStringMissing should mention connection string")
	}
}
// TestStringPoolGetPut checks the string-builder pool. getSB must return a
// usable builder, and a builder obtained after a putSB must start out empty.
func TestStringPoolGetPut(t *testing.T) {
	first := getSB()
	if first == nil {
		t.Fatal("getSB should return non-nil string builder")
	}
	first.WriteString("test")
	if got := first.String(); got != "test" {
		t.Error("String builder should work normally")
	}
	putSB(first)

	second := getSB()
	if second == nil {
		t.Fatal("getSB should return non-nil string builder")
	}
	defer putSB(second)

	if second.Len() != 0 {
		t.Error("String builder from pool should be reset")
	}
}
// TestConditionActionTypes checks that the CondAction constants all have
// distinct values.
func TestConditionActionTypes(t *testing.T) {
	seen := map[CondAction]struct{}{}
	for _, action := range []CondAction{CondActionNothing, CondActionNeedToClose, CondActionSubQuery} {
		if _, dup := seen[action]; dup {
			t.Errorf("Duplicate CondAction value: %v", action)
		}
		seen[action] = struct{}{}
	}
}
// TestSelectQueryBuilderBasics checks that a two-column Select renders the
// SELECT keyword, both columns and the FROM clause.
func TestSelectQueryBuilderBasics(t *testing.T) {
	users := &Table{Name: "users", FieldCount: 10}
	id := Field("users.id")
	email := Field("users.email")

	rendered := users.Select(id, email).String()

	expectations := []struct{ fragment, fail string }{
		{"SELECT", "Query should contain SELECT"},
		{"users.id", "Query should contain users.id field"},
		{"users.email", "Query should contain users.email field"},
		{"FROM users", "Query should contain FROM users"},
	}
	for _, e := range expectations {
		if !strings.Contains(rendered, e.fragment) {
			t.Error(e.fail)
		}
	}
}
// TestWhereConditionAccumulation checks that successive Where calls add
// conditions rather than replacing them. Two calls should yield a WHERE clause
// with at least two "$" placeholders.
func TestWhereConditionAccumulation(t *testing.T) {
	users := &Table{Name: "users", FieldCount: 10}
	id := Field("users.id")
	status := Field("users.status")

	rendered := users.Select(id).
		Where(id.Eq(1)).
		Where(status.Eq("active")).
		String()

	if !strings.Contains(rendered, "WHERE") {
		t.Error("Query should contain WHERE clause")
	}
	if placeholders := strings.Count(rendered, "$"); placeholders < 2 {
		t.Error("Query should have multiple parameterized values")
	}
}
// TestInsertQueryBuilder checks that an Insert on the users table renders an
// INSERT INTO users statement.
func TestInsertQueryBuilder(t *testing.T) {
	users := &Table{Name: "users", FieldCount: 10}
	rendered := users.Insert().String()

	if want := "INSERT INTO users"; !strings.Contains(rendered, want) {
		t.Error("Query should contain INSERT INTO users")
	}
}
// TestUpdateQueryBuilder checks that a filtered Update renders both the
// UPDATE users prefix and a WHERE clause.
func TestUpdateQueryBuilder(t *testing.T) {
	users := &Table{Name: "users", FieldCount: 10}
	id := Field("users.id")

	rendered := users.Update().Where(id.Eq(1)).String()

	for _, e := range []struct{ fragment, fail string }{
		{"UPDATE users", "Query should contain UPDATE users"},
		{"WHERE", "Query should contain WHERE clause"},
	} {
		if !strings.Contains(rendered, e.fragment) {
			t.Error(e.fail)
		}
	}
}
// TestDeleteQueryBuilder checks that a filtered Delete renders both the
// DELETE FROM users prefix and a WHERE clause.
func TestDeleteQueryBuilder(t *testing.T) {
	users := &Table{Name: "users", FieldCount: 10}
	id := Field("users.id")

	rendered := users.Delete().Where(id.Eq(1)).String()

	for _, e := range []struct{ fragment, fail string }{
		{"DELETE FROM users", "Query should contain DELETE FROM users"},
		{"WHERE", "Query should contain WHERE clause"},
	} {
		if !strings.Contains(rendered, e.fragment) {
			t.Error(e.fail)
		}
	}
}
// TestFieldConditioners checks that each comparison/null-check conditioner on
// Field appears with its SQL operator when used in a SELECT ... WHERE query.
func TestFieldConditioners(t *testing.T) {
	age := Field("users.age")

	cases := []struct {
		name     string
		cond     Conditioner
		fragment string
	}{
		{"Eq", age.Eq(25), " = $"},
		{"NotEq", age.NotEq(25), " != $"},
		{"Gt", age.Gt(25), " > $"},
		{"Lt", age.Lt(25), " < $"},
		{"Gte", age.Gte(25), " >= $"},
		{"Lte", age.Lte(25), " <= $"},
		{"IsNull", age.IsNull(), " IS NULL"},
		{"IsNotNull", age.IsNotNull(), " IS NOT NULL"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// A fresh table per subtest keeps each rendering independent.
			users := &Table{Name: "users", FieldCount: 10}
			rendered := users.Select(age).Where(c.cond).String()
			if !strings.Contains(rendered, c.fragment) {
				t.Errorf("%s condition not properly rendered in SQL: %s", c.name, rendered)
			}
		})
	}
}
// TestFieldFunctions checks that the Field helpers wrapping aggregate, string
// and ordering SQL produce a Field whose text contains the expected fragment.
func TestFieldFunctions(t *testing.T) {
	col := Field("users.count")

	cases := []struct {
		name string
		got  Field
		want string
	}{
		{"Count", col.Count(), "COUNT("},
		{"Sum", col.Sum(), "SUM("},
		{"Avg", col.Avg(), "AVG("},
		{"Max", col.Max(), "MAX("},
		// NOTE(review): unlike its siblings this expects mixed-case "Min(".
		// Confirm whether Field.Min should emit "MIN(" for consistency.
		{"Min", col.Min(), "Min("},
		{"Lower", col.Lower(), "LOWER("},
		{"Upper", col.Upper(), "UPPER("},
		{"Trim", col.Trim(), "TRIM("},
		{"Asc", col.Asc(), " ASC"},
		{"Desc", col.Desc(), " DESC"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			text := string(c.got)
			if !strings.Contains(text, c.want) {
				t.Errorf("%s should contain %q, got: %s", c.name, c.want, text)
			}
		})
	}
}
// TestInsertQueryValidation tests that every Insert execution method
// (Exec, ExecTx, First, FirstTx) refuses to run when no fields have been set.
// Each must return an error mentioning "no fields to insert" rather than
// reaching the database.
func TestInsertQueryValidation(t *testing.T) {
	testTable := &Table{Name: "users", FieldCount: 10}
	idField := Field("users.id")

	// requireNoFieldsErr fails the subtest unless err is the "no fields" error.
	// It stops the subtest on a nil err so err.Error() is never called on nil.
	requireNoFieldsErr := func(t *testing.T, method string, err error) {
		t.Helper()
		if err == nil {
			t.Fatalf("Insert.%s() should return error when no fields set", method)
		}
		if !strings.Contains(err.Error(), "no fields to insert") {
			t.Errorf("Expected error about no fields, got: %v", err)
		}
	}

	t.Run("Exec without fields", func(t *testing.T) {
		err := testTable.Insert().Exec(context.Background())
		requireNoFieldsErr(t, "Exec", err)
	})

	t.Run("ExecTx without fields", func(t *testing.T) {
		// A nil transaction is enough: validation must fail before the tx is used.
		err := testTable.Insert().ExecTx(context.Background(), nil)
		requireNoFieldsErr(t, "ExecTx", err)
	})

	t.Run("First without fields", func(t *testing.T) {
		var id int
		err := testTable.Insert().Returning(idField).First(context.Background(), &id)
		requireNoFieldsErr(t, "First", err)
	})

	t.Run("FirstTx without fields", func(t *testing.T) {
		var id int
		err := testTable.Insert().Returning(idField).FirstTx(context.Background(), nil, &id)
		requireNoFieldsErr(t, "FirstTx", err)
	})
}