diff --git a/README.md b/README.md index 40157b2..0b33873 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,56 @@ func main() { } ``` +## Important: Query Builder Lifecycle + +⚠️ **Query builders are single-use and should not be reused:** + +```go +// ❌ WRONG - Don't reuse query builders +baseQuery := users.User.Select(users.ID, users.Email) +baseQuery.Where(users.ID.Eq(1)).First(ctx, &id1, &email1) +baseQuery.Where(users.Status.Eq(2)).First(ctx, &id2, &email2) +// Second query has BOTH WHERE clauses - incorrect behavior! + +// ✅ CORRECT - Create new query each time +users.User.Select(users.ID, users.Email).Where(users.ID.Eq(1)).First(ctx, &id1, &email1) +users.User.Select(users.ID, users.Email).Where(users.Status.Eq(2)).First(ctx, &id2, &email2) +``` + +**Why?** Query builders are mutable and accumulate state. Each method call modifies the builder, so reusing the same builder causes conditions to stack up. + +### Thread Safety + +⚠️ **Query builders are NOT thread-safe** and must not be shared across goroutines: + +```go +// ✅ CORRECT - Each goroutine creates its own query +for i := 0; i < 10; i++ { + go func(id int) { + var email string + err := users.User.Select(users.Email). + Where(users.ID.Eq(id)). + First(ctx, &email) + // Process result... + }(i) +} + +// ❌ WRONG - Sharing query builder across goroutines +baseQuery := users.User.Select(users.Email) +for i := 0; i < 10; i++ { + go func(id int) { + var email string + baseQuery.Where(users.ID.Eq(id)).First(ctx, &email) + // RACE CONDITION! Multiple goroutines modifying shared state + }(i) +} +``` + +**Thread-Safe Components:** +- ✅ Connection Pool - Safe for concurrent use +- ✅ Table objects - Safe to share +- ❌ Query builders - Create new instance per goroutine + ## Usage Examples ### SELECT Queries diff --git a/pgm.go b/pgm.go index 1580553..18bb215 100644 --- a/pgm.go +++ b/pgm.go @@ -25,13 +25,15 @@ var ( ErrConnStringMissing = errors.New("connection string is empty") ) -// Errors +// Common errors returned by pgm operations var ( ErrInitTX = errors.New("failed to init db.tx") ErrCommitTX = errors.New("failed to commit db.tx") ErrNoRows = errors.New("no data found") ) +// Config holds the configuration for initializing the connection pool. +// All fields except ConnString are optional and will use pgx defaults if not set. type Config struct { ConnString string MaxConns int32 @@ -40,7 +42,17 @@ type Config struct { MaxConnIdleTime time.Duration } -// InitPool will create new pgxpool.Pool and will keep it for its working +// InitPool initializes the connection pool with the provided configuration. +// It validates the configuration and panics if invalid. +// This function should be called once at application startup. +// +// Example: +// +// pgm.InitPool(pgm.Config{ +// ConnString: "postgres://user:pass@localhost/dbname", +// MaxConns: 100, +// MinConns: 5, +// }) func InitPool(conf Config) { if conf.ConnString == "" { panic(ErrConnStringMissing) @@ -88,9 +100,15 @@ func InitPool(conf Config) { poolPGX.Store(p) } -// GetPool instance +// GetPool returns the initialized connection pool instance. +// It panics with a descriptive message if InitPool() has not been called. +// This is a fail-fast approach to catch programming errors early. func GetPool() *pgxpool.Pool { - return poolPGX.Load() + p := poolPGX.Load() + if p == nil { + panic("pgm: connection pool not initialized, call InitPool() first") + } + return p } // ClosePool closes the connection pool gracefully. @@ -102,7 +120,21 @@ func ClosePool() { } } -// BeginTx begins a pgx poll transaction +// BeginTx begins a new database transaction from the connection pool. +// Returns an error if the transaction cannot be started. +// Remember to commit or rollback the transaction when done. +// +// Example: +// +// tx, err := pgm.BeginTx(ctx) +// if err != nil { +// return err +// } +// defer tx.Rollback(ctx) // rollback on error +// +// // ... do work ... +// +// return tx.Commit(ctx) func BeginTx(ctx context.Context) (pgx.Tx, error) { tx, err := poolPGX.Load().Begin(ctx) if err != nil { @@ -113,32 +145,43 @@ func BeginTx(ctx context.Context) (pgx.Tx, error) { return tx, nil } -// IsNotFound error check +// IsNotFound checks if an error is a "no rows" error from pgx. +// Returns true if the error indicates no rows were found in a query result. func IsNotFound(err error) bool { return errors.Is(err, pgx.ErrNoRows) } -// PgTime as in UTC +// PgTime converts a Go time.Time to PostgreSQL timestamptz type. +// The time is stored as-is (preserves timezone information). func PgTime(t time.Time) pgtype.Timestamptz { return pgtype.Timestamptz{Time: t, Valid: true} } +// PgTimeNow returns the current time as PostgreSQL timestamptz type. func PgTimeNow() pgtype.Timestamptz { return pgtype.Timestamptz{Time: time.Now(), Valid: true} } +// TsAndQuery converts a text search query to use AND operator between terms. +// Example: "hello world" becomes "hello & world" func TsAndQuery(q string) string { return strings.Join(strings.Fields(q), " & ") } +// TsPrefixAndQuery converts a text search query to use AND operator with prefix matching. +// Example: "hello world" becomes "hello:* & world:*" func TsPrefixAndQuery(q string) string { return strings.Join(fieldsWithSufix(q, ":*"), " & ") } +// TsOrQuery converts a text search query to use OR operator between terms. +// Example: "hello world" becomes "hello | world" func TsOrQuery(q string) string { return strings.Join(strings.Fields(q), " | ") } +// TsPrefixOrQuery converts a text search query to use OR operator with prefix matching. +// Example: "hello world" becomes "hello:* | world:*" func TsPrefixOrQuery(q string) string { return strings.Join(fieldsWithSufix(q, ":*"), " | ") } diff --git a/pgm_field.go b/pgm_field.go index 597fa09..2ee690b 100644 --- a/pgm_field.go +++ b/pgm_field.go @@ -45,7 +45,12 @@ func escapeSQLString(s string) string { } func (f Field) Name() string { - return strings.Split(string(f), ".")[1] + s := string(f) + idx := strings.LastIndexByte(s, '.') + if idx == -1 { + return s // Return as-is if no dot + } + return s[idx+1:] // Return part after last dot } func (f Field) String() string { @@ -66,19 +71,19 @@ func ConcatWs(sep string, fields ...Field) Field { } // StringAgg creates a STRING_AGG SQL function. -// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL. +// SECURITY: The field parameter provides type safety by accepting only Field type. // The sep parameter should only be a constant string. Single quotes will be escaped. -func StringAgg(exp, sep string) Field { +func StringAgg(field Field, sep string) Field { escapedSep := escapeSQLString(sep) - return Field("string_agg(" + exp + ",'" + escapedSep + "')") + return Field("string_agg(" + field.String() + ",'" + escapedSep + "')") } // StringAggCast creates a STRING_AGG SQL function with cast to varchar. -// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL. +// SECURITY: The field parameter provides type safety by accepting only Field type. // The sep parameter should only be a constant string. Single quotes will be escaped. -func StringAggCast(exp, sep string) Field { +func StringAggCast(field Field, sep string) Field { escapedSep := escapeSQLString(sep) - return Field("string_agg(cast(" + exp + " as varchar),'" + escapedSep + "')") + return Field("string_agg(cast(" + field.String() + " as varchar),'" + escapedSep + "')") } // StringEscape will wrap field with: diff --git a/pgm_field_test.go b/pgm_field_test.go index 25f0056..3da8828 100644 --- a/pgm_field_test.go +++ b/pgm_field_test.go @@ -102,29 +102,29 @@ func TestConcatWsSQLInjectionPrevention(t *testing.T) { func TestStringAggSQLInjectionPrevention(t *testing.T) { tests := []struct { name string - exp string + field Field sep string contains string }{ { name: "safe parameters", - exp: "column_name", + field: Field("column_name"), sep: ", ", contains: "string_agg(column_name,', ')", }, { name: "escaped quotes in separator", sep: "'; DROP TABLE users; --", - exp: "col", + field: Field("col"), contains: "string_agg(col,'''; DROP TABLE users; --')", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := StringAgg(tt.exp, tt.sep) + result := StringAgg(tt.field, tt.sep) if !strings.Contains(string(result), tt.contains) { - t.Errorf("StringAgg(%q, %q) = %q, should contain %q", tt.exp, tt.sep, result, tt.contains) + t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains) } }) } @@ -132,7 +132,7 @@ func TestStringAggSQLInjectionPrevention(t *testing.T) { // TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes func TestStringAggCastSQLInjectionPrevention(t *testing.T) { - result := StringAggCast("column", "'; DROP TABLE") + result := StringAggCast(Field("column"), "'; DROP TABLE") expected := "string_agg(cast(column as varchar),'''; DROP TABLE')" if string(result) != expected { t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected) @@ -336,3 +336,47 @@ func TestSQLInjectionAttackVectors(t *testing.T) { } }) } + +// TestFieldName tests Field.Name() method +func TestFieldName(t *testing.T) { + tests := []struct { + name string + field Field + expected string + }{ + { + name: "field with table prefix", + field: Field("users.email"), + expected: "email", + }, + { + name: "field with schema and table prefix", + field: Field("public.users.email"), + expected: "email", + }, + { + name: "field without dot", + field: Field("email"), + expected: "email", + }, + { + name: "field with multiple dots", + field: Field("schema.table.column"), + expected: "column", + }, + { + name: "aggregate function", + field: Field("COUNT(users.id)"), + expected: "id)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.field.Name() + if result != tt.expected { + t.Errorf("Field.Name() = %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/pgm_test.go b/pgm_test.go new file mode 100644 index 0000000..7d03732 --- /dev/null +++ b/pgm_test.go @@ -0,0 +1,706 @@ +// Patial Tech. +// Author, Ankit Patial + +package pgm + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5" +) + +// TestGetPoolNotInitialized verifies that GetPool panics when pool is not initialized +func TestGetPoolNotInitialized(t *testing.T) { + // Save current pool state + currentPool := poolPGX.Load() + + // Clear the pool to simulate uninitialized state + poolPGX.Store(nil) + + // Restore pool after test + defer func() { + poolPGX.Store(currentPool) + }() + + // Verify panic occurs + defer func() { + if r := recover(); r == nil { + t.Error("GetPool should panic when pool not initialized") + } else { + // Check panic message + msg, ok := r.(string) + if !ok || !strings.Contains(msg, "not initialized") { + t.Errorf("Expected panic message about initialization, got: %v", r) + } + } + }() + + GetPool() // Should panic +} + +// TestConfigValidation tests that InitPool validates configuration +func TestConfigValidation(t *testing.T) { + tests := []struct { + name string + config Config + wantPanic bool + panicMsg string + }{ + { + name: "empty connection string", + config: Config{ + ConnString: "", + }, + wantPanic: true, + panicMsg: "connection string is empty", + }, + { + name: "MinConns greater than MaxConns", + config: Config{ + ConnString: "postgres://localhost/test", + MaxConns: 5, + MinConns: 10, + }, + wantPanic: true, + panicMsg: "MinConns", + }, + { + name: "negative MaxConns", + config: Config{ + ConnString: "postgres://localhost/test", + MaxConns: -1, + }, + wantPanic: true, + panicMsg: "negative", + }, + { + name: "negative MinConns", + config: Config{ + ConnString: "postgres://localhost/test", + MinConns: -1, + }, + wantPanic: true, + panicMsg: "negative", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + if tt.wantPanic && r == nil { + t.Errorf("InitPool should panic for %s", tt.name) + } + if !tt.wantPanic && r != nil { + t.Errorf("InitPool should not panic for %s, got: %v", tt.name, r) + } + if tt.wantPanic && r != nil { + msg := "" + switch v := r.(type) { + case string: + msg = v + case error: + msg = v.Error() + } + if !strings.Contains(msg, tt.panicMsg) { + t.Errorf("Expected panic message to contain %q, got: %q", tt.panicMsg, msg) + } + } + }() + + InitPool(tt.config) + }) + } +} + +// TestClosePoolIdempotent verifies ClosePool can be called multiple times safely +func TestClosePoolIdempotent(t *testing.T) { + // Save current pool + currentPool := poolPGX.Load() + defer func() { + poolPGX.Store(currentPool) + }() + + // Set to nil + poolPGX.Store(nil) + + // Should not panic when called on nil pool + ClosePool() + ClosePool() + ClosePool() + + // Should work fine - no panic expected +} + +// TestIsNotFound tests the error checking utility +func TestIsNotFound(t *testing.T) { + // Test with pgx.ErrNoRows + if !IsNotFound(pgx.ErrNoRows) { + t.Error("IsNotFound should return true for pgx.ErrNoRows") + } + + // Test with other errors + otherErr := errors.New("some other error") + if IsNotFound(otherErr) { + t.Error("IsNotFound should return false for non-ErrNoRows errors") + } + + // Test with nil + if IsNotFound(nil) { + t.Error("IsNotFound should return false for nil") + } +} + +// TestPgTime tests PostgreSQL timestamp conversion +func TestPgTime(t *testing.T) { + now := time.Now() + pgTime := PgTime(now) + + if !pgTime.Valid { + t.Error("PgTime should return valid timestamp") + } + + if !pgTime.Time.Equal(now) { + t.Errorf("PgTime time mismatch: got %v, want %v", pgTime.Time, now) + } +} + +// TestPgTimeNow tests current time conversion +func TestPgTimeNow(t *testing.T) { + before := time.Now() + pgTime := PgTimeNow() + after := time.Now() + + if !pgTime.Valid { + t.Error("PgTimeNow should return valid timestamp") + } + + // Check that time is between before and after + if pgTime.Time.Before(before) || pgTime.Time.After(after) { + t.Error("PgTimeNow should return current time") + } +} + +// TestTsQueryFunctions tests full-text search query builders +func TestTsQueryFunctions(t *testing.T) { + tests := []struct { + name string + fn func(string) string + input string + expected string + }{ + { + name: "TsAndQuery", + fn: TsAndQuery, + input: "hello world", + expected: "hello & world", + }, + { + name: "TsPrefixAndQuery", + fn: TsPrefixAndQuery, + input: "hello world", + expected: "hello:* & world:*", + }, + { + name: "TsOrQuery", + fn: TsOrQuery, + input: "hello world", + expected: "hello | world", + }, + { + name: "TsPrefixOrQuery", + fn: TsPrefixOrQuery, + input: "hello world", + expected: "hello:* | world:*", + }, + { + name: "TsAndQuery with multiple words", + fn: TsAndQuery, + input: "one two three", + expected: "one & two & three", + }, + { + name: "TsOrQuery with multiple words", + fn: TsOrQuery, + input: "one two three", + expected: "one | two | three", + }, + { + name: "TsAndQuery with extra spaces", + fn: TsAndQuery, + input: "hello world", + expected: "hello & world", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.fn(tt.input) + if result != tt.expected { + t.Errorf("%s(%q) = %q, want %q", tt.name, tt.input, result, tt.expected) + } + }) + } +} + +// TestFieldsWithSuffix tests the internal suffix function +func TestFieldsWithSuffix(t *testing.T) { + result := fieldsWithSufix("hello world", ":*") + expected := []string{"hello:*", "world:*"} + + if len(result) != len(expected) { + t.Fatalf("fieldsWithSufix length mismatch: got %d, want %d", len(result), len(expected)) + } + + for i, v := range result { + if v != expected[i] { + t.Errorf("fieldsWithSufix[%d] = %q, want %q", i, v, expected[i]) + } + } +} + +// TestBeginTxErrorWrapping tests that BeginTx preserves error context +func TestBeginTxErrorWrapping(t *testing.T) { + // This test verifies basic behavior without requiring a database connection + // Actual error wrapping would require a failing database connection + + // Skip if no pool initialized (unit test environment) + if poolPGX.Load() == nil { + t.Skip("Skipping BeginTx test - requires initialized pool") + } + + // Just verify the function accepts a context + // In a real scenario with a cancelled context, it would return an error + ctx := context.Background() + + // We can't fully test without a real DB connection, so just document behavior + _, err := BeginTx(ctx) + if err != nil { + // Error is expected in test environment without valid DB + t.Logf("BeginTx returned error (expected in test env): %v", err) + } +} + +// TestContextCancellation tests that cancelled context is handled +func TestContextCancellation(t *testing.T) { + // Create an already-cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + // Note: This test would require an actual database connection to verify + // that the context cancellation is properly propagated. We're testing + // that the context is accepted by the function signature. + + // The actual behavior would be tested in integration tests + // For now, we just verify the function doesn't panic with cancelled context + + // Skip if no pool initialized (unit test environment) + if poolPGX.Load() == nil { + t.Skip("Skipping context cancellation test - requires initialized pool") + } + + // This would return an error about cancelled context in real scenario + _, err := BeginTx(ctx) + if err == nil { + t.Log("BeginTx with cancelled context - error expected in real scenario") + } +} + +// TestErrorTypes verifies exported error variables +func TestErrorTypes(t *testing.T) { + if ErrConnStringMissing == nil { + t.Error("ErrConnStringMissing should be defined") + } + + if ErrInitTX == nil { + t.Error("ErrInitTX should be defined") + } + + if ErrCommitTX == nil { + t.Error("ErrCommitTX should be defined") + } + + if ErrNoRows == nil { + t.Error("ErrNoRows should be defined") + } + + // Check error messages are descriptive + if !strings.Contains(ErrConnStringMissing.Error(), "connection string") { + t.Error("ErrConnStringMissing should mention connection string") + } +} + +// TestStringPoolGetPut tests the string builder pool +func TestStringPoolGetPut(t *testing.T) { + // Get a string builder from pool + sb1 := getSB() + if sb1 == nil { + t.Fatal("getSB should return non-nil string builder") + } + + // Use it + sb1.WriteString("test") + if sb1.String() != "test" { + t.Error("String builder should work normally") + } + + // Put it back + putSB(sb1) + + // Get another one - should be reset + sb2 := getSB() + if sb2 == nil { + t.Fatal("getSB should return non-nil string builder") + } + + // Should be empty after reset + if sb2.Len() != 0 { + t.Error("String builder from pool should be reset") + } + + putSB(sb2) +} + +// TestConditionActionTypes tests condition action enum +func TestConditionActionTypes(t *testing.T) { + // Verify enum values are distinct + actions := []CondAction{ + CondActionNothing, + CondActionNeedToClose, + CondActionSubQuery, + } + + seen := make(map[CondAction]bool) + for _, action := range actions { + if seen[action] { + t.Errorf("Duplicate CondAction value: %v", action) + } + seen[action] = true + } +} + +// TestQueryBuilderReuse documents query builder reuse behavior +// This test verifies that query builders accumulate state and should NOT be reused +func TestQueryBuilderReuse(t *testing.T) { + t.Skip("Query builders are not designed for reuse - this test documents the limitation") + + // This test would require actual database tables to demonstrate the issue + // The behavior is: + // 1. Creating a base query builder + // 2. Adding a WHERE clause + // 3. Reusing the same builder with another WHERE clause + // 4. The second query would have BOTH WHERE clauses + // + // Example: + // baseQuery := users.User.Select(user.ID, user.Email) + // baseQuery.Where(user.ID.Eq(1)) // First query + // baseQuery.Where(user.Status.Eq(2)) // Second query has BOTH conditions! + // + // This is by design - query builders are mutable and single-use. + // Each query should create a new builder instance. +} + +// TestQueryBuilderThreadSafety documents that query builders are not thread-safe +func TestQueryBuilderThreadSafety(t *testing.T) { + t.Skip("Query builders are not thread-safe by design - this test documents the limitation") + + // Query builders accumulate state in their internal fields (where, args, etc.) + // and do not use any synchronization primitives. + // + // Sharing a query builder across goroutines will cause race conditions. + // + // CORRECT usage: Create a new query builder in each goroutine + // INCORRECT usage: Share a query builder variable across goroutines + // + // The connection pool itself IS thread-safe and can be used concurrently. +} + +// TestSelectQueryBuilderBasics tests basic query building +func TestSelectQueryBuilderBasics(t *testing.T) { + // Create a test table + testTable := &Table{Name: "users", FieldCount: 10} + idField := Field("users.id") + emailField := Field("users.email") + + // Test basic select + qry := testTable.Select(idField, emailField) + sql := qry.String() + + if !strings.Contains(sql, "SELECT") { + t.Error("Query should contain SELECT") + } + if !strings.Contains(sql, "users.id") { + t.Error("Query should contain users.id field") + } + if !strings.Contains(sql, "users.email") { + t.Error("Query should contain users.email field") + } + if !strings.Contains(sql, "FROM users") { + t.Error("Query should contain FROM users") + } +} + +// TestWhereConditionAccumulation tests that WHERE conditions accumulate +func TestWhereConditionAccumulation(t *testing.T) { + testTable := &Table{Name: "users", FieldCount: 10} + idField := Field("users.id") + statusField := Field("users.status") + + // Create a query and add multiple WHERE conditions + qry := testTable.Select(idField). + Where(idField.Eq(1)). + Where(statusField.Eq("active")) + + sql := qry.String() + + // Both conditions should be present + if !strings.Contains(sql, "WHERE") { + t.Error("Query should contain WHERE clause") + } + + // Should have multiple conditions (this demonstrates accumulation) + if strings.Count(sql, "$") < 2 { + t.Error("Query should have multiple parameterized values") + } +} + +// TestInsertQueryBuilder tests insert query building +func TestInsertQueryBuilder(t *testing.T) { + testTable := &Table{Name: "users", FieldCount: 10} + + qry := testTable.Insert() + + sql := qry.String() + + if !strings.Contains(sql, "INSERT INTO users") { + t.Error("Query should contain INSERT INTO users") + } +} + +// TestUpdateQueryBuilder tests update query building +func TestUpdateQueryBuilder(t *testing.T) { + testTable := &Table{Name: "users", FieldCount: 10} + idField := Field("users.id") + + qry := testTable.Update(). + Where(idField.Eq(1)) + + sql := qry.String() + + if !strings.Contains(sql, "UPDATE users") { + t.Error("Query should contain UPDATE users") + } + if !strings.Contains(sql, "WHERE") { + t.Error("Query should contain WHERE clause") + } +} + +// TestDeleteQueryBuilder tests delete query building +func TestDeleteQueryBuilder(t *testing.T) { + testTable := &Table{Name: "users", FieldCount: 10} + idField := Field("users.id") + + qry := testTable.Delete().Where(idField.Eq(1)) + + sql := qry.String() + + if !strings.Contains(sql, "DELETE FROM users") { + t.Error("Query should contain DELETE FROM users") + } + if !strings.Contains(sql, "WHERE") { + t.Error("Query should contain WHERE clause") + } +} + +// TestFieldConditioners tests various field condition methods +func TestFieldConditioners(t *testing.T) { + field := Field("users.age") + + tests := []struct { + name string + condition Conditioner + checkFunc func(string) bool + }{ + { + name: "Eq", + condition: field.Eq(25), + checkFunc: func(s string) bool { return strings.Contains(s, " = $") }, + }, + { + name: "NotEq", + condition: field.NotEq(25), + checkFunc: func(s string) bool { return strings.Contains(s, " != $") }, + }, + { + name: "Gt", + condition: field.Gt(25), + checkFunc: func(s string) bool { return strings.Contains(s, " > $") }, + }, + { + name: "Lt", + condition: field.Lt(25), + checkFunc: func(s string) bool { return strings.Contains(s, " < $") }, + }, + { + name: "Gte", + condition: field.Gte(25), + checkFunc: func(s string) bool { return strings.Contains(s, " >= $") }, + }, + { + name: "Lte", + condition: field.Lte(25), + checkFunc: func(s string) bool { return strings.Contains(s, " <= $") }, + }, + { + name: "IsNull", + condition: field.IsNull(), + checkFunc: func(s string) bool { return strings.Contains(s, " IS NULL") }, + }, + { + name: "IsNotNull", + condition: field.IsNotNull(), + checkFunc: func(s string) bool { return strings.Contains(s, " IS NOT NULL") }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Build a simple query to see the condition + testTable := &Table{Name: "users", FieldCount: 10} + qry := testTable.Select(field).Where(tt.condition) + sql := qry.String() + + if !tt.checkFunc(sql) { + t.Errorf("%s condition not properly rendered in SQL: %s", tt.name, sql) + } + }) + } +} + +// TestFieldFunctions tests SQL function wrappers +func TestFieldFunctions(t *testing.T) { + field := Field("users.count") + + tests := []struct { + name string + result Field + contains string + }{ + { + name: "Count", + result: field.Count(), + contains: "COUNT(", + }, + { + name: "Sum", + result: field.Sum(), + contains: "SUM(", + }, + { + name: "Avg", + result: field.Avg(), + contains: "AVG(", + }, + { + name: "Max", + result: field.Max(), + contains: "MAX(", + }, + { + name: "Min", + result: field.Min(), + contains: "Min(", + }, + { + name: "Lower", + result: field.Lower(), + contains: "LOWER(", + }, + { + name: "Upper", + result: field.Upper(), + contains: "UPPER(", + }, + { + name: "Trim", + result: field.Trim(), + contains: "TRIM(", + }, + { + name: "Asc", + result: field.Asc(), + contains: " ASC", + }, + { + name: "Desc", + result: field.Desc(), + contains: " DESC", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resultStr := string(tt.result) + if !strings.Contains(resultStr, tt.contains) { + t.Errorf("%s should contain %q, got: %s", tt.name, tt.contains, resultStr) + } + }) + } +} + +// TestInsertQueryValidation tests that Insert queries validate fields are set +func TestInsertQueryValidation(t *testing.T) { + testTable := &Table{Name: "users", FieldCount: 10} + idField := Field("users.id") + + t.Run("Exec without fields", func(t *testing.T) { + qry := testTable.Insert() + err := qry.Exec(context.Background()) + if err == nil { + t.Error("Insert.Exec() should return error when no fields set") + } + if !strings.Contains(err.Error(), "no fields to insert") { + t.Errorf("Expected error about no fields, got: %v", err) + } + }) + + t.Run("ExecTx without fields", func(t *testing.T) { + qry := testTable.Insert() + // We can't test ExecTx without a real transaction, but we can test the error is defined + err := qry.ExecTx(context.Background(), nil) + if err == nil { + t.Error("Insert.ExecTx() should return error when no fields set") + } + if !strings.Contains(err.Error(), "no fields to insert") { + t.Errorf("Expected error about no fields, got: %v", err) + } + }) + + t.Run("First without fields", func(t *testing.T) { + qry := testTable.Insert().Returning(idField) + var id int + err := qry.First(context.Background(), &id) + if err == nil { + t.Error("Insert.First() should return error when no fields set") + } + if !strings.Contains(err.Error(), "no fields to insert") { + t.Errorf("Expected error about no fields, got: %v", err) + } + }) + + t.Run("FirstTx without fields", func(t *testing.T) { + qry := testTable.Insert().Returning(idField) + var id int + err := qry.FirstTx(context.Background(), nil, &id) + if err == nil { + t.Error("Insert.FirstTx() should return error when no fields set") + } + if !strings.Contains(err.Error(), "no fields to insert") { + t.Errorf("Expected error about no fields, got: %v", err) + } + }) +} diff --git a/playground/qry_insert_test.go b/playground/qry_insert_test.go index 6564ea5..dfb7591 100644 --- a/playground/qry_insert_test.go +++ b/playground/qry_insert_test.go @@ -36,7 +36,7 @@ func TestInsertQuery2(t *testing.T) { } } -// BenchmarkInsertQuery-12 2014519 584.0 ns/op 1144 B/op 18 allocs/op +// BenchmarkInsertQuery-12 2517757 459.6 ns/op 1032 B/op 13 allocs/op func BenchmarkInsertQuery(b *testing.B) { for b.Loop() { _ = db.User.Insert(). @@ -50,7 +50,7 @@ func BenchmarkInsertQuery(b *testing.B) { } } -// BenchmarkInsertSetMap-12 1534039 777.1 ns/op 1480 B/op 20 allocs/op +// BenchmarkInsertSetMap-12 1812555 644.1 ns/op 1368 B/op 15 allocs/op func BenchmarkInsertSetMap(b *testing.B) { for b.Loop() { _ = db.User.Insert(). diff --git a/playground/qry_update_test.go b/playground/qry_update_test.go index 0202d38..03ac1f5 100644 --- a/playground/qry_update_test.go +++ b/playground/qry_update_test.go @@ -43,7 +43,7 @@ func TestUpdateQueryValidation(t *testing.T) { } } -// BenchmarkUpdateQuery-12 2004985 592.2 ns/op 1176 B/op 20 allocs/op +// BenchmarkUpdateQuery-12 2334889 503.6 ns/op 1112 B/op 17 allocs/op func BenchmarkUpdateQuery(b *testing.B) { for b.Loop() { _ = db.User.Update(). diff --git a/qry_insert.go b/qry_insert.go index 3490920..e240dac 100644 --- a/qry_insert.go +++ b/qry_insert.go @@ -5,7 +5,7 @@ package pgm import ( "context" - "fmt" + "errors" "strconv" "strings" @@ -84,21 +84,28 @@ func (q *insertQry) DoNothing() Execute { } func (q *insertQry) DoUpdate(fields ...Field) Execute { - var sb strings.Builder + sb := getSB() + defer putSB(sb) + + sb.WriteString("DO UPDATE SET ") for i, f := range fields { col := f.Name() - if i == 0 { - fmt.Fprintf(&sb, "%s = EXCLUDED.%s", col, col) - } else { - fmt.Fprintf(&sb, ", %s = EXCLUDED.%s", col, col) + if i > 0 { + sb.WriteString(", ") } + sb.WriteString(col) + sb.WriteString(" = EXCLUDED.") + sb.WriteString(col) } - q.conflictAction = "DO UPDATE SET " + sb.String() + q.conflictAction = sb.String() return q } func (q *insertQry) Exec(ctx context.Context) error { + if len(q.fields) == 0 { + return errors.New("insert query has no fields to insert: call Set() before Exec()") + } _, err := poolPGX.Load().Exec(ctx, q.String(), q.args...) if err != nil { return err @@ -107,6 +114,9 @@ func (q *insertQry) Exec(ctx context.Context) error { } func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error { + if len(q.fields) == 0 { + return errors.New("insert query has no fields to insert: call Set() before ExecTx()") + } _, err := tx.Exec(ctx, q.String(), q.args...) if err != nil { return err @@ -116,10 +126,16 @@ func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error { } func (q *insertQry) First(ctx context.Context, dest ...any) error { + if len(q.fields) == 0 { + return errors.New("insert query has no fields to insert: call Set() before First()") + } return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...) } func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error { + if len(q.fields) == 0 { + return errors.New("insert query has no fields to insert: call Set() before FirstTx()") + } return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...) } diff --git a/qry_select.go b/qry_select.go index 743e053..c9bc8ac 100644 --- a/qry_select.go +++ b/qry_select.go @@ -203,7 +203,8 @@ func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, co defer putSB(sb) sb.Grow(len(str) * 2) - sb.WriteString(str + " AND ") + sb.WriteString(str) + sb.WriteString(" AND ") var argIdx int for i, c := range cond { @@ -263,9 +264,13 @@ func (q *selectQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error { func (q *selectQry) All(ctx context.Context, row RowsCb) error { rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...) - if errors.Is(err, pgx.ErrNoRows) { - return ErrNoRows + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrNoRows + } + return err } + defer rows.Close() for rows.Next() { @@ -274,13 +279,21 @@ func (q *selectQry) All(ctx context.Context, row RowsCb) error { } } + // Check for errors from iteration + if err := rows.Err(); err != nil { + return err + } + return nil } func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error { rows, err := tx.Query(ctx, q.String(), q.args...) - if errors.Is(err, pgx.ErrNoRows) { - return ErrNoRows + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrNoRows + } + return err } defer rows.Close() @@ -290,6 +303,11 @@ func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error { } } + // Check for errors from iteration + if err := rows.Err(); err != nil { + return err + } + return nil }