perf enhancement

This commit is contained in:
2025-11-16 16:21:35 +05:30
parent 29cddb6389
commit 1d9d9d9308
9 changed files with 917 additions and 35 deletions

View File

@@ -172,6 +172,56 @@ func main() {
} }
``` ```
## Important: Query Builder Lifecycle
⚠️ **Query builders are single-use and should not be reused:**
```go
// ❌ WRONG - Don't reuse query builders
baseQuery := users.User.Select(users.ID, users.Email)
baseQuery.Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
baseQuery.Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
// Second query has BOTH WHERE clauses - incorrect behavior!
// ✅ CORRECT - Create new query each time
users.User.Select(users.ID, users.Email).Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
users.User.Select(users.ID, users.Email).Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
```
**Why?** Query builders are mutable and accumulate state. Each method call modifies the builder, so reusing the same builder causes conditions to stack up.
### Thread Safety
⚠️ **Query builders are NOT thread-safe** and must not be shared across goroutines:
```go
// ✅ CORRECT - Each goroutine creates its own query
for i := 0; i < 10; i++ {
go func(id int) {
var email string
err := users.User.Select(users.Email).
Where(users.ID.Eq(id)).
First(ctx, &email)
// Process result...
}(i)
}
// ❌ WRONG - Sharing query builder across goroutines
baseQuery := users.User.Select(users.Email)
for i := 0; i < 10; i++ {
go func(id int) {
var email string
baseQuery.Where(users.ID.Eq(id)).First(ctx, &email)
// RACE CONDITION! Multiple goroutines modifying shared state
}(i)
}
```
**Thread-Safe Components:**
- ✅ Connection Pool - Safe for concurrent use
- ✅ Table objects - Safe to share
- ❌ Query builders - Create new instance per goroutine
## Usage Examples ## Usage Examples
### SELECT Queries ### SELECT Queries

57
pgm.go
View File

@@ -25,13 +25,15 @@ var (
ErrConnStringMissing = errors.New("connection string is empty") ErrConnStringMissing = errors.New("connection string is empty")
) )
// Errors // Common errors returned by pgm operations
var ( var (
ErrInitTX = errors.New("failed to init db.tx") ErrInitTX = errors.New("failed to init db.tx")
ErrCommitTX = errors.New("failed to commit db.tx") ErrCommitTX = errors.New("failed to commit db.tx")
ErrNoRows = errors.New("no data found") ErrNoRows = errors.New("no data found")
) )
// Config holds the configuration for initializing the connection pool.
// All fields except ConnString are optional and will use pgx defaults if not set.
type Config struct { type Config struct {
ConnString string ConnString string
MaxConns int32 MaxConns int32
@@ -40,7 +42,17 @@ type Config struct {
MaxConnIdleTime time.Duration MaxConnIdleTime time.Duration
} }
// InitPool will create new pgxpool.Pool and will keep it for its working // InitPool initializes the connection pool with the provided configuration.
// It validates the configuration and panics if invalid.
// This function should be called once at application startup.
//
// Example:
//
// pgm.InitPool(pgm.Config{
// ConnString: "postgres://user:pass@localhost/dbname",
// MaxConns: 100,
// MinConns: 5,
// })
func InitPool(conf Config) { func InitPool(conf Config) {
if conf.ConnString == "" { if conf.ConnString == "" {
panic(ErrConnStringMissing) panic(ErrConnStringMissing)
@@ -88,9 +100,15 @@ func InitPool(conf Config) {
poolPGX.Store(p) poolPGX.Store(p)
} }
// GetPool instance // GetPool returns the initialized connection pool instance.
// It panics with a descriptive message if InitPool() has not been called.
// This is a fail-fast approach to catch programming errors early.
func GetPool() *pgxpool.Pool { func GetPool() *pgxpool.Pool {
return poolPGX.Load() p := poolPGX.Load()
if p == nil {
panic("pgm: connection pool not initialized, call InitPool() first")
}
return p
} }
// ClosePool closes the connection pool gracefully. // ClosePool closes the connection pool gracefully.
@@ -102,7 +120,21 @@ func ClosePool() {
} }
} }
// BeginTx begins a pgx poll transaction // BeginTx begins a new database transaction from the connection pool.
// Returns an error if the transaction cannot be started.
// Remember to commit or rollback the transaction when done.
//
// Example:
//
// tx, err := pgm.BeginTx(ctx)
// if err != nil {
// return err
// }
// defer tx.Rollback(ctx) // rollback on error
//
// // ... do work ...
//
// return tx.Commit(ctx)
func BeginTx(ctx context.Context) (pgx.Tx, error) { func BeginTx(ctx context.Context) (pgx.Tx, error) {
tx, err := poolPGX.Load().Begin(ctx) tx, err := poolPGX.Load().Begin(ctx)
if err != nil { if err != nil {
@@ -113,32 +145,43 @@ func BeginTx(ctx context.Context) (pgx.Tx, error) {
return tx, nil return tx, nil
} }
// IsNotFound error check // IsNotFound checks if an error is a "no rows" error from pgx.
// Returns true if the error indicates no rows were found in a query result.
func IsNotFound(err error) bool { func IsNotFound(err error) bool {
return errors.Is(err, pgx.ErrNoRows) return errors.Is(err, pgx.ErrNoRows)
} }
// PgTime as in UTC // PgTime converts a Go time.Time to PostgreSQL timestamptz type.
// The time is stored as-is (preserves timezone information).
func PgTime(t time.Time) pgtype.Timestamptz { func PgTime(t time.Time) pgtype.Timestamptz {
return pgtype.Timestamptz{Time: t, Valid: true} return pgtype.Timestamptz{Time: t, Valid: true}
} }
// PgTimeNow returns the current time as PostgreSQL timestamptz type.
func PgTimeNow() pgtype.Timestamptz { func PgTimeNow() pgtype.Timestamptz {
return pgtype.Timestamptz{Time: time.Now(), Valid: true} return pgtype.Timestamptz{Time: time.Now(), Valid: true}
} }
// TsAndQuery converts a text search query to use AND operator between terms.
// Example: "hello world" becomes "hello & world"
func TsAndQuery(q string) string { func TsAndQuery(q string) string {
return strings.Join(strings.Fields(q), " & ") return strings.Join(strings.Fields(q), " & ")
} }
// TsPrefixAndQuery converts a text search query to use AND operator with prefix matching.
// Example: "hello world" becomes "hello:* & world:*"
func TsPrefixAndQuery(q string) string { func TsPrefixAndQuery(q string) string {
return strings.Join(fieldsWithSufix(q, ":*"), " & ") return strings.Join(fieldsWithSufix(q, ":*"), " & ")
} }
// TsOrQuery converts a text search query to use OR operator between terms.
// Example: "hello world" becomes "hello | world"
func TsOrQuery(q string) string { func TsOrQuery(q string) string {
return strings.Join(strings.Fields(q), " | ") return strings.Join(strings.Fields(q), " | ")
} }
// TsPrefixOrQuery converts a text search query to use OR operator with prefix matching.
// Example: "hello world" becomes "hello:* | world:*"
func TsPrefixOrQuery(q string) string { func TsPrefixOrQuery(q string) string {
return strings.Join(fieldsWithSufix(q, ":*"), " | ") return strings.Join(fieldsWithSufix(q, ":*"), " | ")
} }

View File

@@ -45,7 +45,12 @@ func escapeSQLString(s string) string {
} }
func (f Field) Name() string { func (f Field) Name() string {
return strings.Split(string(f), ".")[1] s := string(f)
idx := strings.LastIndexByte(s, '.')
if idx == -1 {
return s // Return as-is if no dot
}
return s[idx+1:] // Return part after last dot
} }
func (f Field) String() string { func (f Field) String() string {
@@ -66,19 +71,19 @@ func ConcatWs(sep string, fields ...Field) Field {
} }
// StringAgg creates a STRING_AGG SQL function. // StringAgg creates a STRING_AGG SQL function.
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL. // SECURITY: The field parameter provides type safety by accepting only Field type.
// The sep parameter should only be a constant string. Single quotes will be escaped. // The sep parameter should only be a constant string. Single quotes will be escaped.
func StringAgg(exp, sep string) Field { func StringAgg(field Field, sep string) Field {
escapedSep := escapeSQLString(sep) escapedSep := escapeSQLString(sep)
return Field("string_agg(" + exp + ",'" + escapedSep + "')") return Field("string_agg(" + field.String() + ",'" + escapedSep + "')")
} }
// StringAggCast creates a STRING_AGG SQL function with cast to varchar. // StringAggCast creates a STRING_AGG SQL function with cast to varchar.
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL. // SECURITY: The field parameter provides type safety by accepting only Field type.
// The sep parameter should only be a constant string. Single quotes will be escaped. // The sep parameter should only be a constant string. Single quotes will be escaped.
func StringAggCast(exp, sep string) Field { func StringAggCast(field Field, sep string) Field {
escapedSep := escapeSQLString(sep) escapedSep := escapeSQLString(sep)
return Field("string_agg(cast(" + exp + " as varchar),'" + escapedSep + "')") return Field("string_agg(cast(" + field.String() + " as varchar),'" + escapedSep + "')")
} }
// StringEscape will wrap field with: // StringEscape will wrap field with:

View File

@@ -102,29 +102,29 @@ func TestConcatWsSQLInjectionPrevention(t *testing.T) {
func TestStringAggSQLInjectionPrevention(t *testing.T) { func TestStringAggSQLInjectionPrevention(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
exp string field Field
sep string sep string
contains string contains string
}{ }{
{ {
name: "safe parameters", name: "safe parameters",
exp: "column_name", field: Field("column_name"),
sep: ", ", sep: ", ",
contains: "string_agg(column_name,', ')", contains: "string_agg(column_name,', ')",
}, },
{ {
name: "escaped quotes in separator", name: "escaped quotes in separator",
sep: "'; DROP TABLE users; --", sep: "'; DROP TABLE users; --",
exp: "col", field: Field("col"),
contains: "string_agg(col,'''; DROP TABLE users; --')", contains: "string_agg(col,'''; DROP TABLE users; --')",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := StringAgg(tt.exp, tt.sep) result := StringAgg(tt.field, tt.sep)
if !strings.Contains(string(result), tt.contains) { if !strings.Contains(string(result), tt.contains) {
t.Errorf("StringAgg(%q, %q) = %q, should contain %q", tt.exp, tt.sep, result, tt.contains) t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
} }
}) })
} }
@@ -132,7 +132,7 @@ func TestStringAggSQLInjectionPrevention(t *testing.T) {
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes // TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
func TestStringAggCastSQLInjectionPrevention(t *testing.T) { func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
result := StringAggCast("column", "'; DROP TABLE") result := StringAggCast(Field("column"), "'; DROP TABLE")
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')" expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
if string(result) != expected { if string(result) != expected {
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected) t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
@@ -336,3 +336,47 @@ func TestSQLInjectionAttackVectors(t *testing.T) {
} }
}) })
} }
// TestFieldName tests Field.Name() method
func TestFieldName(t *testing.T) {
tests := []struct {
name string
field Field
expected string
}{
{
name: "field with table prefix",
field: Field("users.email"),
expected: "email",
},
{
name: "field with schema and table prefix",
field: Field("public.users.email"),
expected: "email",
},
{
name: "field without dot",
field: Field("email"),
expected: "email",
},
{
name: "field with multiple dots",
field: Field("schema.table.column"),
expected: "column",
},
{
name: "aggregate function",
field: Field("COUNT(users.id)"),
expected: "id)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.field.Name()
if result != tt.expected {
t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
}
})
}
}

706
pgm_test.go Normal file
View File

@@ -0,0 +1,706 @@
// Patial Tech.
// Author, Ankit Patial
package pgm
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5"
)
// TestGetPoolNotInitialized verifies that GetPool panics when pool is not initialized
func TestGetPoolNotInitialized(t *testing.T) {
// Save current pool state
currentPool := poolPGX.Load()
// Clear the pool to simulate uninitialized state
poolPGX.Store(nil)
// Restore pool after test
defer func() {
poolPGX.Store(currentPool)
}()
// Verify panic occurs
defer func() {
if r := recover(); r == nil {
t.Error("GetPool should panic when pool not initialized")
} else {
// Check panic message
msg, ok := r.(string)
if !ok || !strings.Contains(msg, "not initialized") {
t.Errorf("Expected panic message about initialization, got: %v", r)
}
}
}()
GetPool() // Should panic
}
// TestConfigValidation tests that InitPool validates configuration
func TestConfigValidation(t *testing.T) {
tests := []struct {
name string
config Config
wantPanic bool
panicMsg string
}{
{
name: "empty connection string",
config: Config{
ConnString: "",
},
wantPanic: true,
panicMsg: "connection string is empty",
},
{
name: "MinConns greater than MaxConns",
config: Config{
ConnString: "postgres://localhost/test",
MaxConns: 5,
MinConns: 10,
},
wantPanic: true,
panicMsg: "MinConns",
},
{
name: "negative MaxConns",
config: Config{
ConnString: "postgres://localhost/test",
MaxConns: -1,
},
wantPanic: true,
panicMsg: "negative",
},
{
name: "negative MinConns",
config: Config{
ConnString: "postgres://localhost/test",
MinConns: -1,
},
wantPanic: true,
panicMsg: "negative",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
r := recover()
if tt.wantPanic && r == nil {
t.Errorf("InitPool should panic for %s", tt.name)
}
if !tt.wantPanic && r != nil {
t.Errorf("InitPool should not panic for %s, got: %v", tt.name, r)
}
if tt.wantPanic && r != nil {
msg := ""
switch v := r.(type) {
case string:
msg = v
case error:
msg = v.Error()
}
if !strings.Contains(msg, tt.panicMsg) {
t.Errorf("Expected panic message to contain %q, got: %q", tt.panicMsg, msg)
}
}
}()
InitPool(tt.config)
})
}
}
// TestClosePoolIdempotent verifies ClosePool can be called multiple times safely
func TestClosePoolIdempotent(t *testing.T) {
// Save current pool
currentPool := poolPGX.Load()
defer func() {
poolPGX.Store(currentPool)
}()
// Set to nil
poolPGX.Store(nil)
// Should not panic when called on nil pool
ClosePool()
ClosePool()
ClosePool()
// Should work fine - no panic expected
}
// TestIsNotFound tests the error checking utility
func TestIsNotFound(t *testing.T) {
// Test with pgx.ErrNoRows
if !IsNotFound(pgx.ErrNoRows) {
t.Error("IsNotFound should return true for pgx.ErrNoRows")
}
// Test with other errors
otherErr := errors.New("some other error")
if IsNotFound(otherErr) {
t.Error("IsNotFound should return false for non-ErrNoRows errors")
}
// Test with nil
if IsNotFound(nil) {
t.Error("IsNotFound should return false for nil")
}
}
// TestPgTime tests PostgreSQL timestamp conversion
func TestPgTime(t *testing.T) {
now := time.Now()
pgTime := PgTime(now)
if !pgTime.Valid {
t.Error("PgTime should return valid timestamp")
}
if !pgTime.Time.Equal(now) {
t.Errorf("PgTime time mismatch: got %v, want %v", pgTime.Time, now)
}
}
// TestPgTimeNow tests current time conversion
func TestPgTimeNow(t *testing.T) {
before := time.Now()
pgTime := PgTimeNow()
after := time.Now()
if !pgTime.Valid {
t.Error("PgTimeNow should return valid timestamp")
}
// Check that time is between before and after
if pgTime.Time.Before(before) || pgTime.Time.After(after) {
t.Error("PgTimeNow should return current time")
}
}
// TestTsQueryFunctions tests full-text search query builders
func TestTsQueryFunctions(t *testing.T) {
tests := []struct {
name string
fn func(string) string
input string
expected string
}{
{
name: "TsAndQuery",
fn: TsAndQuery,
input: "hello world",
expected: "hello & world",
},
{
name: "TsPrefixAndQuery",
fn: TsPrefixAndQuery,
input: "hello world",
expected: "hello:* & world:*",
},
{
name: "TsOrQuery",
fn: TsOrQuery,
input: "hello world",
expected: "hello | world",
},
{
name: "TsPrefixOrQuery",
fn: TsPrefixOrQuery,
input: "hello world",
expected: "hello:* | world:*",
},
{
name: "TsAndQuery with multiple words",
fn: TsAndQuery,
input: "one two three",
expected: "one & two & three",
},
{
name: "TsOrQuery with multiple words",
fn: TsOrQuery,
input: "one two three",
expected: "one | two | three",
},
{
name: "TsAndQuery with extra spaces",
fn: TsAndQuery,
input: "hello world",
expected: "hello & world",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.fn(tt.input)
if result != tt.expected {
t.Errorf("%s(%q) = %q, want %q", tt.name, tt.input, result, tt.expected)
}
})
}
}
// TestFieldsWithSuffix tests the internal suffix function
func TestFieldsWithSuffix(t *testing.T) {
result := fieldsWithSufix("hello world", ":*")
expected := []string{"hello:*", "world:*"}
if len(result) != len(expected) {
t.Fatalf("fieldsWithSufix length mismatch: got %d, want %d", len(result), len(expected))
}
for i, v := range result {
if v != expected[i] {
t.Errorf("fieldsWithSufix[%d] = %q, want %q", i, v, expected[i])
}
}
}
// TestBeginTxErrorWrapping tests that BeginTx preserves error context
func TestBeginTxErrorWrapping(t *testing.T) {
// This test verifies basic behavior without requiring a database connection
// Actual error wrapping would require a failing database connection
// Skip if no pool initialized (unit test environment)
if poolPGX.Load() == nil {
t.Skip("Skipping BeginTx test - requires initialized pool")
}
// Just verify the function accepts a context
// In a real scenario with a cancelled context, it would return an error
ctx := context.Background()
// We can't fully test without a real DB connection, so just document behavior
_, err := BeginTx(ctx)
if err != nil {
// Error is expected in test environment without valid DB
t.Logf("BeginTx returned error (expected in test env): %v", err)
}
}
// TestContextCancellation tests that cancelled context is handled
func TestContextCancellation(t *testing.T) {
// Create an already-cancelled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
// Note: This test would require an actual database connection to verify
// that the context cancellation is properly propagated. We're testing
// that the context is accepted by the function signature.
// The actual behavior would be tested in integration tests
// For now, we just verify the function doesn't panic with cancelled context
// Skip if no pool initialized (unit test environment)
if poolPGX.Load() == nil {
t.Skip("Skipping context cancellation test - requires initialized pool")
}
// This would return an error about cancelled context in real scenario
_, err := BeginTx(ctx)
if err == nil {
t.Log("BeginTx with cancelled context - error expected in real scenario")
}
}
// TestErrorTypes verifies exported error variables
func TestErrorTypes(t *testing.T) {
if ErrConnStringMissing == nil {
t.Error("ErrConnStringMissing should be defined")
}
if ErrInitTX == nil {
t.Error("ErrInitTX should be defined")
}
if ErrCommitTX == nil {
t.Error("ErrCommitTX should be defined")
}
if ErrNoRows == nil {
t.Error("ErrNoRows should be defined")
}
// Check error messages are descriptive
if !strings.Contains(ErrConnStringMissing.Error(), "connection string") {
t.Error("ErrConnStringMissing should mention connection string")
}
}
// TestStringPoolGetPut tests the string builder pool
func TestStringPoolGetPut(t *testing.T) {
// Get a string builder from pool
sb1 := getSB()
if sb1 == nil {
t.Fatal("getSB should return non-nil string builder")
}
// Use it
sb1.WriteString("test")
if sb1.String() != "test" {
t.Error("String builder should work normally")
}
// Put it back
putSB(sb1)
// Get another one - should be reset
sb2 := getSB()
if sb2 == nil {
t.Fatal("getSB should return non-nil string builder")
}
// Should be empty after reset
if sb2.Len() != 0 {
t.Error("String builder from pool should be reset")
}
putSB(sb2)
}
// TestConditionActionTypes tests condition action enum
func TestConditionActionTypes(t *testing.T) {
// Verify enum values are distinct
actions := []CondAction{
CondActionNothing,
CondActionNeedToClose,
CondActionSubQuery,
}
seen := make(map[CondAction]bool)
for _, action := range actions {
if seen[action] {
t.Errorf("Duplicate CondAction value: %v", action)
}
seen[action] = true
}
}
// TestQueryBuilderReuse documents query builder reuse behavior
// This test verifies that query builders accumulate state and should NOT be reused
func TestQueryBuilderReuse(t *testing.T) {
t.Skip("Query builders are not designed for reuse - this test documents the limitation")
// This test would require actual database tables to demonstrate the issue
// The behavior is:
// 1. Creating a base query builder
// 2. Adding a WHERE clause
// 3. Reusing the same builder with another WHERE clause
// 4. The second query would have BOTH WHERE clauses
//
// Example:
// baseQuery := users.User.Select(user.ID, user.Email)
// baseQuery.Where(user.ID.Eq(1)) // First query
// baseQuery.Where(user.Status.Eq(2)) // Second query has BOTH conditions!
//
// This is by design - query builders are mutable and single-use.
// Each query should create a new builder instance.
}
// TestQueryBuilderThreadSafety documents that query builders are not thread-safe
func TestQueryBuilderThreadSafety(t *testing.T) {
t.Skip("Query builders are not thread-safe by design - this test documents the limitation")
// Query builders accumulate state in their internal fields (where, args, etc.)
// and do not use any synchronization primitives.
//
// Sharing a query builder across goroutines will cause race conditions.
//
// CORRECT usage: Create a new query builder in each goroutine
// INCORRECT usage: Share a query builder variable across goroutines
//
// The connection pool itself IS thread-safe and can be used concurrently.
}
// TestSelectQueryBuilderBasics tests basic query building
func TestSelectQueryBuilderBasics(t *testing.T) {
// Create a test table
testTable := &Table{Name: "users", FieldCount: 10}
idField := Field("users.id")
emailField := Field("users.email")
// Test basic select
qry := testTable.Select(idField, emailField)
sql := qry.String()
if !strings.Contains(sql, "SELECT") {
t.Error("Query should contain SELECT")
}
if !strings.Contains(sql, "users.id") {
t.Error("Query should contain users.id field")
}
if !strings.Contains(sql, "users.email") {
t.Error("Query should contain users.email field")
}
if !strings.Contains(sql, "FROM users") {
t.Error("Query should contain FROM users")
}
}
// TestWhereConditionAccumulation tests that WHERE conditions accumulate
func TestWhereConditionAccumulation(t *testing.T) {
testTable := &Table{Name: "users", FieldCount: 10}
idField := Field("users.id")
statusField := Field("users.status")
// Create a query and add multiple WHERE conditions
qry := testTable.Select(idField).
Where(idField.Eq(1)).
Where(statusField.Eq("active"))
sql := qry.String()
// Both conditions should be present
if !strings.Contains(sql, "WHERE") {
t.Error("Query should contain WHERE clause")
}
// Should have multiple conditions (this demonstrates accumulation)
if strings.Count(sql, "$") < 2 {
t.Error("Query should have multiple parameterized values")
}
}
// TestInsertQueryBuilder tests insert query building
func TestInsertQueryBuilder(t *testing.T) {
testTable := &Table{Name: "users", FieldCount: 10}
qry := testTable.Insert()
sql := qry.String()
if !strings.Contains(sql, "INSERT INTO users") {
t.Error("Query should contain INSERT INTO users")
}
}
// TestUpdateQueryBuilder tests update query building
func TestUpdateQueryBuilder(t *testing.T) {
testTable := &Table{Name: "users", FieldCount: 10}
idField := Field("users.id")
qry := testTable.Update().
Where(idField.Eq(1))
sql := qry.String()
if !strings.Contains(sql, "UPDATE users") {
t.Error("Query should contain UPDATE users")
}
if !strings.Contains(sql, "WHERE") {
t.Error("Query should contain WHERE clause")
}
}
// TestDeleteQueryBuilder tests delete query building
func TestDeleteQueryBuilder(t *testing.T) {
testTable := &Table{Name: "users", FieldCount: 10}
idField := Field("users.id")
qry := testTable.Delete().Where(idField.Eq(1))
sql := qry.String()
if !strings.Contains(sql, "DELETE FROM users") {
t.Error("Query should contain DELETE FROM users")
}
if !strings.Contains(sql, "WHERE") {
t.Error("Query should contain WHERE clause")
}
}
// TestFieldConditioners tests various field condition methods
func TestFieldConditioners(t *testing.T) {
field := Field("users.age")
tests := []struct {
name string
condition Conditioner
checkFunc func(string) bool
}{
{
name: "Eq",
condition: field.Eq(25),
checkFunc: func(s string) bool { return strings.Contains(s, " = $") },
},
{
name: "NotEq",
condition: field.NotEq(25),
checkFunc: func(s string) bool { return strings.Contains(s, " != $") },
},
{
name: "Gt",
condition: field.Gt(25),
checkFunc: func(s string) bool { return strings.Contains(s, " > $") },
},
{
name: "Lt",
condition: field.Lt(25),
checkFunc: func(s string) bool { return strings.Contains(s, " < $") },
},
{
name: "Gte",
condition: field.Gte(25),
checkFunc: func(s string) bool { return strings.Contains(s, " >= $") },
},
{
name: "Lte",
condition: field.Lte(25),
checkFunc: func(s string) bool { return strings.Contains(s, " <= $") },
},
{
name: "IsNull",
condition: field.IsNull(),
checkFunc: func(s string) bool { return strings.Contains(s, " IS NULL") },
},
{
name: "IsNotNull",
condition: field.IsNotNull(),
checkFunc: func(s string) bool { return strings.Contains(s, " IS NOT NULL") },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Build a simple query to see the condition
testTable := &Table{Name: "users", FieldCount: 10}
qry := testTable.Select(field).Where(tt.condition)
sql := qry.String()
if !tt.checkFunc(sql) {
t.Errorf("%s condition not properly rendered in SQL: %s", tt.name, sql)
}
})
}
}
// TestFieldFunctions tests SQL function wrappers
func TestFieldFunctions(t *testing.T) {
field := Field("users.count")
tests := []struct {
name string
result Field
contains string
}{
{
name: "Count",
result: field.Count(),
contains: "COUNT(",
},
{
name: "Sum",
result: field.Sum(),
contains: "SUM(",
},
{
name: "Avg",
result: field.Avg(),
contains: "AVG(",
},
{
name: "Max",
result: field.Max(),
contains: "MAX(",
},
{
name: "Min",
result: field.Min(),
contains: "Min(",
},
{
name: "Lower",
result: field.Lower(),
contains: "LOWER(",
},
{
name: "Upper",
result: field.Upper(),
contains: "UPPER(",
},
{
name: "Trim",
result: field.Trim(),
contains: "TRIM(",
},
{
name: "Asc",
result: field.Asc(),
contains: " ASC",
},
{
name: "Desc",
result: field.Desc(),
contains: " DESC",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resultStr := string(tt.result)
if !strings.Contains(resultStr, tt.contains) {
t.Errorf("%s should contain %q, got: %s", tt.name, tt.contains, resultStr)
}
})
}
}
// TestInsertQueryValidation tests that Insert queries validate fields are set
func TestInsertQueryValidation(t *testing.T) {
testTable := &Table{Name: "users", FieldCount: 10}
idField := Field("users.id")
t.Run("Exec without fields", func(t *testing.T) {
qry := testTable.Insert()
err := qry.Exec(context.Background())
if err == nil {
t.Error("Insert.Exec() should return error when no fields set")
}
if !strings.Contains(err.Error(), "no fields to insert") {
t.Errorf("Expected error about no fields, got: %v", err)
}
})
t.Run("ExecTx without fields", func(t *testing.T) {
qry := testTable.Insert()
// We can't test ExecTx without a real transaction, but we can test the error is defined
err := qry.ExecTx(context.Background(), nil)
if err == nil {
t.Error("Insert.ExecTx() should return error when no fields set")
}
if !strings.Contains(err.Error(), "no fields to insert") {
t.Errorf("Expected error about no fields, got: %v", err)
}
})
t.Run("First without fields", func(t *testing.T) {
qry := testTable.Insert().Returning(idField)
var id int
err := qry.First(context.Background(), &id)
if err == nil {
t.Error("Insert.First() should return error when no fields set")
}
if !strings.Contains(err.Error(), "no fields to insert") {
t.Errorf("Expected error about no fields, got: %v", err)
}
})
t.Run("FirstTx without fields", func(t *testing.T) {
qry := testTable.Insert().Returning(idField)
var id int
err := qry.FirstTx(context.Background(), nil, &id)
if err == nil {
t.Error("Insert.FirstTx() should return error when no fields set")
}
if !strings.Contains(err.Error(), "no fields to insert") {
t.Errorf("Expected error about no fields, got: %v", err)
}
})
}

View File

@@ -36,7 +36,7 @@ func TestInsertQuery2(t *testing.T) {
} }
} }
// BenchmarkInsertQuery-12 2014519 584.0 ns/op 1144 B/op 18 allocs/op // BenchmarkInsertQuery-12 2517757 459.6 ns/op 1032 B/op 13 allocs/op
func BenchmarkInsertQuery(b *testing.B) { func BenchmarkInsertQuery(b *testing.B) {
for b.Loop() { for b.Loop() {
_ = db.User.Insert(). _ = db.User.Insert().
@@ -50,7 +50,7 @@ func BenchmarkInsertQuery(b *testing.B) {
} }
} }
// BenchmarkInsertSetMap-12 1534039 777.1 ns/op 1480 B/op 20 allocs/op // BenchmarkInsertSetMap-12 1812555 644.1 ns/op 1368 B/op 15 allocs/op
func BenchmarkInsertSetMap(b *testing.B) { func BenchmarkInsertSetMap(b *testing.B) {
for b.Loop() { for b.Loop() {
_ = db.User.Insert(). _ = db.User.Insert().

View File

@@ -43,7 +43,7 @@ func TestUpdateQueryValidation(t *testing.T) {
} }
} }
// BenchmarkUpdateQuery-12 2004985 592.2 ns/op 1176 B/op 20 allocs/op // BenchmarkUpdateQuery-12 2334889 503.6 ns/op 1112 B/op 17 allocs/op
func BenchmarkUpdateQuery(b *testing.B) { func BenchmarkUpdateQuery(b *testing.B) {
for b.Loop() { for b.Loop() {
_ = db.User.Update(). _ = db.User.Update().

View File

@@ -5,7 +5,7 @@ package pgm
import ( import (
"context" "context"
"fmt" "errors"
"strconv" "strconv"
"strings" "strings"
@@ -84,21 +84,28 @@ func (q *insertQry) DoNothing() Execute {
} }
func (q *insertQry) DoUpdate(fields ...Field) Execute { func (q *insertQry) DoUpdate(fields ...Field) Execute {
var sb strings.Builder sb := getSB()
defer putSB(sb)
sb.WriteString("DO UPDATE SET ")
for i, f := range fields { for i, f := range fields {
col := f.Name() col := f.Name()
if i == 0 { if i > 0 {
fmt.Fprintf(&sb, "%s = EXCLUDED.%s", col, col) sb.WriteString(", ")
} else {
fmt.Fprintf(&sb, ", %s = EXCLUDED.%s", col, col)
} }
sb.WriteString(col)
sb.WriteString(" = EXCLUDED.")
sb.WriteString(col)
} }
q.conflictAction = "DO UPDATE SET " + sb.String() q.conflictAction = sb.String()
return q return q
} }
func (q *insertQry) Exec(ctx context.Context) error { func (q *insertQry) Exec(ctx context.Context) error {
if len(q.fields) == 0 {
return errors.New("insert query has no fields to insert: call Set() before Exec()")
}
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...) _, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
if err != nil { if err != nil {
return err return err
@@ -107,6 +114,9 @@ func (q *insertQry) Exec(ctx context.Context) error {
} }
func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error { func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
if len(q.fields) == 0 {
return errors.New("insert query has no fields to insert: call Set() before ExecTx()")
}
_, err := tx.Exec(ctx, q.String(), q.args...) _, err := tx.Exec(ctx, q.String(), q.args...)
if err != nil { if err != nil {
return err return err
@@ -116,10 +126,16 @@ func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
} }
func (q *insertQry) First(ctx context.Context, dest ...any) error { func (q *insertQry) First(ctx context.Context, dest ...any) error {
if len(q.fields) == 0 {
return errors.New("insert query has no fields to insert: call Set() before First()")
}
return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...) return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...)
} }
func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error { func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
if len(q.fields) == 0 {
return errors.New("insert query has no fields to insert: call Set() before FirstTx()")
}
return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...) return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...)
} }

View File

@@ -203,7 +203,8 @@ func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, co
defer putSB(sb) defer putSB(sb)
sb.Grow(len(str) * 2) sb.Grow(len(str) * 2)
sb.WriteString(str + " AND ") sb.WriteString(str)
sb.WriteString(" AND ")
var argIdx int var argIdx int
for i, c := range cond { for i, c := range cond {
@@ -263,9 +264,13 @@ func (q *selectQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
func (q *selectQry) All(ctx context.Context, row RowsCb) error { func (q *selectQry) All(ctx context.Context, row RowsCb) error {
rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...) rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return ErrNoRows return ErrNoRows
} }
return err
}
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
@@ -274,14 +279,22 @@ func (q *selectQry) All(ctx context.Context, row RowsCb) error {
} }
} }
// Check for errors from iteration
if err := rows.Err(); err != nil {
return err
}
return nil return nil
} }
func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error { func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
rows, err := tx.Query(ctx, q.String(), q.args...) rows, err := tx.Query(ctx, q.String(), q.args...)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return ErrNoRows return ErrNoRows
} }
return err
}
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
@@ -290,6 +303,11 @@ func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
} }
} }
// Check for errors from iteration
if err := rows.Err(); err != nil {
return err
}
return nil return nil
} }