forked from go/pgm
perf enhancement
This commit is contained in:
50
README.md
50
README.md
@@ -172,6 +172,56 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
## Important: Query Builder Lifecycle
|
||||
|
||||
⚠️ **Query builders are single-use and should not be reused:**
|
||||
|
||||
```go
|
||||
// ❌ WRONG - Don't reuse query builders
|
||||
baseQuery := users.User.Select(users.ID, users.Email)
|
||||
baseQuery.Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||
baseQuery.Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||
// Second query has BOTH WHERE clauses - incorrect behavior!
|
||||
|
||||
// ✅ CORRECT - Create new query each time
|
||||
users.User.Select(users.ID, users.Email).Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||
users.User.Select(users.ID, users.Email).Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||
```
|
||||
|
||||
**Why?** Query builders are mutable and accumulate state. Each method call modifies the builder, so reusing the same builder causes conditions to stack up.
|
||||
|
||||
### Thread Safety
|
||||
|
||||
⚠️ **Query builders are NOT thread-safe** and must not be shared across goroutines:
|
||||
|
||||
```go
|
||||
// ✅ CORRECT - Each goroutine creates its own query
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
var email string
|
||||
err := users.User.Select(users.Email).
|
||||
Where(users.ID.Eq(id)).
|
||||
First(ctx, &email)
|
||||
// Process result...
|
||||
}(i)
|
||||
}
|
||||
|
||||
// ❌ WRONG - Sharing query builder across goroutines
|
||||
baseQuery := users.User.Select(users.Email)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
var email string
|
||||
baseQuery.Where(users.ID.Eq(id)).First(ctx, &email)
|
||||
// RACE CONDITION! Multiple goroutines modifying shared state
|
||||
}(i)
|
||||
}
|
||||
```
|
||||
|
||||
**Thread-Safe Components:**
|
||||
- ✅ Connection Pool - Safe for concurrent use
|
||||
- ✅ Table objects - Safe to share
|
||||
- ❌ Query builders - Create new instance per goroutine
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### SELECT Queries
|
||||
|
||||
57
pgm.go
57
pgm.go
@@ -25,13 +25,15 @@ var (
|
||||
ErrConnStringMissing = errors.New("connection string is empty")
|
||||
)
|
||||
|
||||
// Errors
|
||||
// Common errors returned by pgm operations
|
||||
var (
|
||||
ErrInitTX = errors.New("failed to init db.tx")
|
||||
ErrCommitTX = errors.New("failed to commit db.tx")
|
||||
ErrNoRows = errors.New("no data found")
|
||||
)
|
||||
|
||||
// Config holds the configuration for initializing the connection pool.
|
||||
// All fields except ConnString are optional and will use pgx defaults if not set.
|
||||
type Config struct {
|
||||
ConnString string
|
||||
MaxConns int32
|
||||
@@ -40,7 +42,17 @@ type Config struct {
|
||||
MaxConnIdleTime time.Duration
|
||||
}
|
||||
|
||||
// InitPool will create new pgxpool.Pool and will keep it for its working
|
||||
// InitPool initializes the connection pool with the provided configuration.
|
||||
// It validates the configuration and panics if invalid.
|
||||
// This function should be called once at application startup.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// pgm.InitPool(pgm.Config{
|
||||
// ConnString: "postgres://user:pass@localhost/dbname",
|
||||
// MaxConns: 100,
|
||||
// MinConns: 5,
|
||||
// })
|
||||
func InitPool(conf Config) {
|
||||
if conf.ConnString == "" {
|
||||
panic(ErrConnStringMissing)
|
||||
@@ -88,9 +100,15 @@ func InitPool(conf Config) {
|
||||
poolPGX.Store(p)
|
||||
}
|
||||
|
||||
// GetPool instance
|
||||
// GetPool returns the initialized connection pool instance.
|
||||
// It panics with a descriptive message if InitPool() has not been called.
|
||||
// This is a fail-fast approach to catch programming errors early.
|
||||
func GetPool() *pgxpool.Pool {
|
||||
return poolPGX.Load()
|
||||
p := poolPGX.Load()
|
||||
if p == nil {
|
||||
panic("pgm: connection pool not initialized, call InitPool() first")
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ClosePool closes the connection pool gracefully.
|
||||
@@ -102,7 +120,21 @@ func ClosePool() {
|
||||
}
|
||||
}
|
||||
|
||||
// BeginTx begins a pgx poll transaction
|
||||
// BeginTx begins a new database transaction from the connection pool.
|
||||
// Returns an error if the transaction cannot be started.
|
||||
// Remember to commit or rollback the transaction when done.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// tx, err := pgm.BeginTx(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer tx.Rollback(ctx) // rollback on error
|
||||
//
|
||||
// // ... do work ...
|
||||
//
|
||||
// return tx.Commit(ctx)
|
||||
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
||||
tx, err := poolPGX.Load().Begin(ctx)
|
||||
if err != nil {
|
||||
@@ -113,32 +145,43 @@ func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
// IsNotFound error check
|
||||
// IsNotFound checks if an error is a "no rows" error from pgx.
|
||||
// Returns true if the error indicates no rows were found in a query result.
|
||||
func IsNotFound(err error) bool {
|
||||
return errors.Is(err, pgx.ErrNoRows)
|
||||
}
|
||||
|
||||
// PgTime as in UTC
|
||||
// PgTime converts a Go time.Time to PostgreSQL timestamptz type.
|
||||
// The time is stored as-is (preserves timezone information).
|
||||
func PgTime(t time.Time) pgtype.Timestamptz {
|
||||
return pgtype.Timestamptz{Time: t, Valid: true}
|
||||
}
|
||||
|
||||
// PgTimeNow returns the current time as PostgreSQL timestamptz type.
|
||||
func PgTimeNow() pgtype.Timestamptz {
|
||||
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
||||
}
|
||||
|
||||
// TsAndQuery converts a text search query to use AND operator between terms.
|
||||
// Example: "hello world" becomes "hello & world"
|
||||
func TsAndQuery(q string) string {
|
||||
return strings.Join(strings.Fields(q), " & ")
|
||||
}
|
||||
|
||||
// TsPrefixAndQuery converts a text search query to use AND operator with prefix matching.
|
||||
// Example: "hello world" becomes "hello:* & world:*"
|
||||
func TsPrefixAndQuery(q string) string {
|
||||
return strings.Join(fieldsWithSufix(q, ":*"), " & ")
|
||||
}
|
||||
|
||||
// TsOrQuery converts a text search query to use OR operator between terms.
|
||||
// Example: "hello world" becomes "hello | world"
|
||||
func TsOrQuery(q string) string {
|
||||
return strings.Join(strings.Fields(q), " | ")
|
||||
}
|
||||
|
||||
// TsPrefixOrQuery converts a text search query to use OR operator with prefix matching.
|
||||
// Example: "hello world" becomes "hello:* | world:*"
|
||||
func TsPrefixOrQuery(q string) string {
|
||||
return strings.Join(fieldsWithSufix(q, ":*"), " | ")
|
||||
}
|
||||
|
||||
19
pgm_field.go
19
pgm_field.go
@@ -45,7 +45,12 @@ func escapeSQLString(s string) string {
|
||||
}
|
||||
|
||||
func (f Field) Name() string {
|
||||
return strings.Split(string(f), ".")[1]
|
||||
s := string(f)
|
||||
idx := strings.LastIndexByte(s, '.')
|
||||
if idx == -1 {
|
||||
return s // Return as-is if no dot
|
||||
}
|
||||
return s[idx+1:] // Return part after last dot
|
||||
}
|
||||
|
||||
func (f Field) String() string {
|
||||
@@ -66,19 +71,19 @@ func ConcatWs(sep string, fields ...Field) Field {
|
||||
}
|
||||
|
||||
// StringAgg creates a STRING_AGG SQL function.
|
||||
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
|
||||
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||
func StringAgg(exp, sep string) Field {
|
||||
func StringAgg(field Field, sep string) Field {
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("string_agg(" + exp + ",'" + escapedSep + "')")
|
||||
return Field("string_agg(" + field.String() + ",'" + escapedSep + "')")
|
||||
}
|
||||
|
||||
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
|
||||
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
|
||||
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||
func StringAggCast(exp, sep string) Field {
|
||||
func StringAggCast(field Field, sep string) Field {
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("string_agg(cast(" + exp + " as varchar),'" + escapedSep + "')")
|
||||
return Field("string_agg(cast(" + field.String() + " as varchar),'" + escapedSep + "')")
|
||||
}
|
||||
|
||||
// StringEscape will wrap field with:
|
||||
|
||||
@@ -102,29 +102,29 @@ func TestConcatWsSQLInjectionPrevention(t *testing.T) {
|
||||
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exp string
|
||||
field Field
|
||||
sep string
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "safe parameters",
|
||||
exp: "column_name",
|
||||
field: Field("column_name"),
|
||||
sep: ", ",
|
||||
contains: "string_agg(column_name,', ')",
|
||||
},
|
||||
{
|
||||
name: "escaped quotes in separator",
|
||||
sep: "'; DROP TABLE users; --",
|
||||
exp: "col",
|
||||
field: Field("col"),
|
||||
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := StringAgg(tt.exp, tt.sep)
|
||||
result := StringAgg(tt.field, tt.sep)
|
||||
if !strings.Contains(string(result), tt.contains) {
|
||||
t.Errorf("StringAgg(%q, %q) = %q, should contain %q", tt.exp, tt.sep, result, tt.contains)
|
||||
t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||
|
||||
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
||||
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
||||
result := StringAggCast("column", "'; DROP TABLE")
|
||||
result := StringAggCast(Field("column"), "'; DROP TABLE")
|
||||
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
||||
if string(result) != expected {
|
||||
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
||||
@@ -336,3 +336,47 @@ func TestSQLInjectionAttackVectors(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFieldName tests Field.Name() method
|
||||
func TestFieldName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field Field
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "field with table prefix",
|
||||
field: Field("users.email"),
|
||||
expected: "email",
|
||||
},
|
||||
{
|
||||
name: "field with schema and table prefix",
|
||||
field: Field("public.users.email"),
|
||||
expected: "email",
|
||||
},
|
||||
{
|
||||
name: "field without dot",
|
||||
field: Field("email"),
|
||||
expected: "email",
|
||||
},
|
||||
{
|
||||
name: "field with multiple dots",
|
||||
field: Field("schema.table.column"),
|
||||
expected: "column",
|
||||
},
|
||||
{
|
||||
name: "aggregate function",
|
||||
field: Field("COUNT(users.id)"),
|
||||
expected: "id)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.field.Name()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
706
pgm_test.go
Normal file
706
pgm_test.go
Normal file
@@ -0,0 +1,706 @@
|
||||
// Patial Tech.
|
||||
// Author, Ankit Patial
|
||||
|
||||
package pgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
// TestGetPoolNotInitialized verifies that GetPool panics when pool is not initialized
|
||||
func TestGetPoolNotInitialized(t *testing.T) {
|
||||
// Save current pool state
|
||||
currentPool := poolPGX.Load()
|
||||
|
||||
// Clear the pool to simulate uninitialized state
|
||||
poolPGX.Store(nil)
|
||||
|
||||
// Restore pool after test
|
||||
defer func() {
|
||||
poolPGX.Store(currentPool)
|
||||
}()
|
||||
|
||||
// Verify panic occurs
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("GetPool should panic when pool not initialized")
|
||||
} else {
|
||||
// Check panic message
|
||||
msg, ok := r.(string)
|
||||
if !ok || !strings.Contains(msg, "not initialized") {
|
||||
t.Errorf("Expected panic message about initialization, got: %v", r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
GetPool() // Should panic
|
||||
}
|
||||
|
||||
// TestConfigValidation tests that InitPool validates configuration
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
wantPanic bool
|
||||
panicMsg string
|
||||
}{
|
||||
{
|
||||
name: "empty connection string",
|
||||
config: Config{
|
||||
ConnString: "",
|
||||
},
|
||||
wantPanic: true,
|
||||
panicMsg: "connection string is empty",
|
||||
},
|
||||
{
|
||||
name: "MinConns greater than MaxConns",
|
||||
config: Config{
|
||||
ConnString: "postgres://localhost/test",
|
||||
MaxConns: 5,
|
||||
MinConns: 10,
|
||||
},
|
||||
wantPanic: true,
|
||||
panicMsg: "MinConns",
|
||||
},
|
||||
{
|
||||
name: "negative MaxConns",
|
||||
config: Config{
|
||||
ConnString: "postgres://localhost/test",
|
||||
MaxConns: -1,
|
||||
},
|
||||
wantPanic: true,
|
||||
panicMsg: "negative",
|
||||
},
|
||||
{
|
||||
name: "negative MinConns",
|
||||
config: Config{
|
||||
ConnString: "postgres://localhost/test",
|
||||
MinConns: -1,
|
||||
},
|
||||
wantPanic: true,
|
||||
panicMsg: "negative",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
if tt.wantPanic && r == nil {
|
||||
t.Errorf("InitPool should panic for %s", tt.name)
|
||||
}
|
||||
if !tt.wantPanic && r != nil {
|
||||
t.Errorf("InitPool should not panic for %s, got: %v", tt.name, r)
|
||||
}
|
||||
if tt.wantPanic && r != nil {
|
||||
msg := ""
|
||||
switch v := r.(type) {
|
||||
case string:
|
||||
msg = v
|
||||
case error:
|
||||
msg = v.Error()
|
||||
}
|
||||
if !strings.Contains(msg, tt.panicMsg) {
|
||||
t.Errorf("Expected panic message to contain %q, got: %q", tt.panicMsg, msg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
InitPool(tt.config)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClosePoolIdempotent verifies ClosePool can be called multiple times safely
|
||||
func TestClosePoolIdempotent(t *testing.T) {
|
||||
// Save current pool
|
||||
currentPool := poolPGX.Load()
|
||||
defer func() {
|
||||
poolPGX.Store(currentPool)
|
||||
}()
|
||||
|
||||
// Set to nil
|
||||
poolPGX.Store(nil)
|
||||
|
||||
// Should not panic when called on nil pool
|
||||
ClosePool()
|
||||
ClosePool()
|
||||
ClosePool()
|
||||
|
||||
// Should work fine - no panic expected
|
||||
}
|
||||
|
||||
// TestIsNotFound tests the error checking utility
|
||||
func TestIsNotFound(t *testing.T) {
|
||||
// Test with pgx.ErrNoRows
|
||||
if !IsNotFound(pgx.ErrNoRows) {
|
||||
t.Error("IsNotFound should return true for pgx.ErrNoRows")
|
||||
}
|
||||
|
||||
// Test with other errors
|
||||
otherErr := errors.New("some other error")
|
||||
if IsNotFound(otherErr) {
|
||||
t.Error("IsNotFound should return false for non-ErrNoRows errors")
|
||||
}
|
||||
|
||||
// Test with nil
|
||||
if IsNotFound(nil) {
|
||||
t.Error("IsNotFound should return false for nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPgTime tests PostgreSQL timestamp conversion
|
||||
func TestPgTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
pgTime := PgTime(now)
|
||||
|
||||
if !pgTime.Valid {
|
||||
t.Error("PgTime should return valid timestamp")
|
||||
}
|
||||
|
||||
if !pgTime.Time.Equal(now) {
|
||||
t.Errorf("PgTime time mismatch: got %v, want %v", pgTime.Time, now)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPgTimeNow tests current time conversion
|
||||
func TestPgTimeNow(t *testing.T) {
|
||||
before := time.Now()
|
||||
pgTime := PgTimeNow()
|
||||
after := time.Now()
|
||||
|
||||
if !pgTime.Valid {
|
||||
t.Error("PgTimeNow should return valid timestamp")
|
||||
}
|
||||
|
||||
// Check that time is between before and after
|
||||
if pgTime.Time.Before(before) || pgTime.Time.After(after) {
|
||||
t.Error("PgTimeNow should return current time")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTsQueryFunctions tests full-text search query builders
|
||||
func TestTsQueryFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(string) string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "TsAndQuery",
|
||||
fn: TsAndQuery,
|
||||
input: "hello world",
|
||||
expected: "hello & world",
|
||||
},
|
||||
{
|
||||
name: "TsPrefixAndQuery",
|
||||
fn: TsPrefixAndQuery,
|
||||
input: "hello world",
|
||||
expected: "hello:* & world:*",
|
||||
},
|
||||
{
|
||||
name: "TsOrQuery",
|
||||
fn: TsOrQuery,
|
||||
input: "hello world",
|
||||
expected: "hello | world",
|
||||
},
|
||||
{
|
||||
name: "TsPrefixOrQuery",
|
||||
fn: TsPrefixOrQuery,
|
||||
input: "hello world",
|
||||
expected: "hello:* | world:*",
|
||||
},
|
||||
{
|
||||
name: "TsAndQuery with multiple words",
|
||||
fn: TsAndQuery,
|
||||
input: "one two three",
|
||||
expected: "one & two & three",
|
||||
},
|
||||
{
|
||||
name: "TsOrQuery with multiple words",
|
||||
fn: TsOrQuery,
|
||||
input: "one two three",
|
||||
expected: "one | two | three",
|
||||
},
|
||||
{
|
||||
name: "TsAndQuery with extra spaces",
|
||||
fn: TsAndQuery,
|
||||
input: "hello world",
|
||||
expected: "hello & world",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.fn(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("%s(%q) = %q, want %q", tt.name, tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFieldsWithSuffix tests the internal suffix function
|
||||
func TestFieldsWithSuffix(t *testing.T) {
|
||||
result := fieldsWithSufix("hello world", ":*")
|
||||
expected := []string{"hello:*", "world:*"}
|
||||
|
||||
if len(result) != len(expected) {
|
||||
t.Fatalf("fieldsWithSufix length mismatch: got %d, want %d", len(result), len(expected))
|
||||
}
|
||||
|
||||
for i, v := range result {
|
||||
if v != expected[i] {
|
||||
t.Errorf("fieldsWithSufix[%d] = %q, want %q", i, v, expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBeginTxErrorWrapping tests that BeginTx preserves error context
|
||||
func TestBeginTxErrorWrapping(t *testing.T) {
|
||||
// This test verifies basic behavior without requiring a database connection
|
||||
// Actual error wrapping would require a failing database connection
|
||||
|
||||
// Skip if no pool initialized (unit test environment)
|
||||
if poolPGX.Load() == nil {
|
||||
t.Skip("Skipping BeginTx test - requires initialized pool")
|
||||
}
|
||||
|
||||
// Just verify the function accepts a context
|
||||
// In a real scenario with a cancelled context, it would return an error
|
||||
ctx := context.Background()
|
||||
|
||||
// We can't fully test without a real DB connection, so just document behavior
|
||||
_, err := BeginTx(ctx)
|
||||
if err != nil {
|
||||
// Error is expected in test environment without valid DB
|
||||
t.Logf("BeginTx returned error (expected in test env): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestContextCancellation tests that cancelled context is handled
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
// Create an already-cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
// Note: This test would require an actual database connection to verify
|
||||
// that the context cancellation is properly propagated. We're testing
|
||||
// that the context is accepted by the function signature.
|
||||
|
||||
// The actual behavior would be tested in integration tests
|
||||
// For now, we just verify the function doesn't panic with cancelled context
|
||||
|
||||
// Skip if no pool initialized (unit test environment)
|
||||
if poolPGX.Load() == nil {
|
||||
t.Skip("Skipping context cancellation test - requires initialized pool")
|
||||
}
|
||||
|
||||
// This would return an error about cancelled context in real scenario
|
||||
_, err := BeginTx(ctx)
|
||||
if err == nil {
|
||||
t.Log("BeginTx with cancelled context - error expected in real scenario")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorTypes verifies exported error variables
|
||||
func TestErrorTypes(t *testing.T) {
|
||||
if ErrConnStringMissing == nil {
|
||||
t.Error("ErrConnStringMissing should be defined")
|
||||
}
|
||||
|
||||
if ErrInitTX == nil {
|
||||
t.Error("ErrInitTX should be defined")
|
||||
}
|
||||
|
||||
if ErrCommitTX == nil {
|
||||
t.Error("ErrCommitTX should be defined")
|
||||
}
|
||||
|
||||
if ErrNoRows == nil {
|
||||
t.Error("ErrNoRows should be defined")
|
||||
}
|
||||
|
||||
// Check error messages are descriptive
|
||||
if !strings.Contains(ErrConnStringMissing.Error(), "connection string") {
|
||||
t.Error("ErrConnStringMissing should mention connection string")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStringPoolGetPut tests the string builder pool
|
||||
func TestStringPoolGetPut(t *testing.T) {
|
||||
// Get a string builder from pool
|
||||
sb1 := getSB()
|
||||
if sb1 == nil {
|
||||
t.Fatal("getSB should return non-nil string builder")
|
||||
}
|
||||
|
||||
// Use it
|
||||
sb1.WriteString("test")
|
||||
if sb1.String() != "test" {
|
||||
t.Error("String builder should work normally")
|
||||
}
|
||||
|
||||
// Put it back
|
||||
putSB(sb1)
|
||||
|
||||
// Get another one - should be reset
|
||||
sb2 := getSB()
|
||||
if sb2 == nil {
|
||||
t.Fatal("getSB should return non-nil string builder")
|
||||
}
|
||||
|
||||
// Should be empty after reset
|
||||
if sb2.Len() != 0 {
|
||||
t.Error("String builder from pool should be reset")
|
||||
}
|
||||
|
||||
putSB(sb2)
|
||||
}
|
||||
|
||||
// TestConditionActionTypes tests condition action enum
|
||||
func TestConditionActionTypes(t *testing.T) {
|
||||
// Verify enum values are distinct
|
||||
actions := []CondAction{
|
||||
CondActionNothing,
|
||||
CondActionNeedToClose,
|
||||
CondActionSubQuery,
|
||||
}
|
||||
|
||||
seen := make(map[CondAction]bool)
|
||||
for _, action := range actions {
|
||||
if seen[action] {
|
||||
t.Errorf("Duplicate CondAction value: %v", action)
|
||||
}
|
||||
seen[action] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueryBuilderReuse documents query builder reuse behavior
|
||||
// This test verifies that query builders accumulate state and should NOT be reused
|
||||
func TestQueryBuilderReuse(t *testing.T) {
|
||||
t.Skip("Query builders are not designed for reuse - this test documents the limitation")
|
||||
|
||||
// This test would require actual database tables to demonstrate the issue
|
||||
// The behavior is:
|
||||
// 1. Creating a base query builder
|
||||
// 2. Adding a WHERE clause
|
||||
// 3. Reusing the same builder with another WHERE clause
|
||||
// 4. The second query would have BOTH WHERE clauses
|
||||
//
|
||||
// Example:
|
||||
// baseQuery := users.User.Select(user.ID, user.Email)
|
||||
// baseQuery.Where(user.ID.Eq(1)) // First query
|
||||
// baseQuery.Where(user.Status.Eq(2)) // Second query has BOTH conditions!
|
||||
//
|
||||
// This is by design - query builders are mutable and single-use.
|
||||
// Each query should create a new builder instance.
|
||||
}
|
||||
|
||||
// TestQueryBuilderThreadSafety documents that query builders are not thread-safe
|
||||
func TestQueryBuilderThreadSafety(t *testing.T) {
|
||||
t.Skip("Query builders are not thread-safe by design - this test documents the limitation")
|
||||
|
||||
// Query builders accumulate state in their internal fields (where, args, etc.)
|
||||
// and do not use any synchronization primitives.
|
||||
//
|
||||
// Sharing a query builder across goroutines will cause race conditions.
|
||||
//
|
||||
// CORRECT usage: Create a new query builder in each goroutine
|
||||
// INCORRECT usage: Share a query builder variable across goroutines
|
||||
//
|
||||
// The connection pool itself IS thread-safe and can be used concurrently.
|
||||
}
|
||||
|
||||
// TestSelectQueryBuilderBasics tests basic query building
|
||||
func TestSelectQueryBuilderBasics(t *testing.T) {
|
||||
// Create a test table
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
emailField := Field("users.email")
|
||||
|
||||
// Test basic select
|
||||
qry := testTable.Select(idField, emailField)
|
||||
sql := qry.String()
|
||||
|
||||
if !strings.Contains(sql, "SELECT") {
|
||||
t.Error("Query should contain SELECT")
|
||||
}
|
||||
if !strings.Contains(sql, "users.id") {
|
||||
t.Error("Query should contain users.id field")
|
||||
}
|
||||
if !strings.Contains(sql, "users.email") {
|
||||
t.Error("Query should contain users.email field")
|
||||
}
|
||||
if !strings.Contains(sql, "FROM users") {
|
||||
t.Error("Query should contain FROM users")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhereConditionAccumulation tests that WHERE conditions accumulate
|
||||
func TestWhereConditionAccumulation(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
statusField := Field("users.status")
|
||||
|
||||
// Create a query and add multiple WHERE conditions
|
||||
qry := testTable.Select(idField).
|
||||
Where(idField.Eq(1)).
|
||||
Where(statusField.Eq("active"))
|
||||
|
||||
sql := qry.String()
|
||||
|
||||
// Both conditions should be present
|
||||
if !strings.Contains(sql, "WHERE") {
|
||||
t.Error("Query should contain WHERE clause")
|
||||
}
|
||||
|
||||
// Should have multiple conditions (this demonstrates accumulation)
|
||||
if strings.Count(sql, "$") < 2 {
|
||||
t.Error("Query should have multiple parameterized values")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInsertQueryBuilder tests insert query building
|
||||
func TestInsertQueryBuilder(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
|
||||
qry := testTable.Insert()
|
||||
|
||||
sql := qry.String()
|
||||
|
||||
if !strings.Contains(sql, "INSERT INTO users") {
|
||||
t.Error("Query should contain INSERT INTO users")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateQueryBuilder tests update query building
|
||||
func TestUpdateQueryBuilder(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
|
||||
qry := testTable.Update().
|
||||
Where(idField.Eq(1))
|
||||
|
||||
sql := qry.String()
|
||||
|
||||
if !strings.Contains(sql, "UPDATE users") {
|
||||
t.Error("Query should contain UPDATE users")
|
||||
}
|
||||
if !strings.Contains(sql, "WHERE") {
|
||||
t.Error("Query should contain WHERE clause")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteQueryBuilder tests delete query building
|
||||
func TestDeleteQueryBuilder(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
|
||||
qry := testTable.Delete().Where(idField.Eq(1))
|
||||
|
||||
sql := qry.String()
|
||||
|
||||
if !strings.Contains(sql, "DELETE FROM users") {
|
||||
t.Error("Query should contain DELETE FROM users")
|
||||
}
|
||||
if !strings.Contains(sql, "WHERE") {
|
||||
t.Error("Query should contain WHERE clause")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFieldConditioners tests various field condition methods
|
||||
func TestFieldConditioners(t *testing.T) {
|
||||
field := Field("users.age")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
condition Conditioner
|
||||
checkFunc func(string) bool
|
||||
}{
|
||||
{
|
||||
name: "Eq",
|
||||
condition: field.Eq(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " = $") },
|
||||
},
|
||||
{
|
||||
name: "NotEq",
|
||||
condition: field.NotEq(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " != $") },
|
||||
},
|
||||
{
|
||||
name: "Gt",
|
||||
condition: field.Gt(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " > $") },
|
||||
},
|
||||
{
|
||||
name: "Lt",
|
||||
condition: field.Lt(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " < $") },
|
||||
},
|
||||
{
|
||||
name: "Gte",
|
||||
condition: field.Gte(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " >= $") },
|
||||
},
|
||||
{
|
||||
name: "Lte",
|
||||
condition: field.Lte(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " <= $") },
|
||||
},
|
||||
{
|
||||
name: "IsNull",
|
||||
condition: field.IsNull(),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " IS NULL") },
|
||||
},
|
||||
{
|
||||
name: "IsNotNull",
|
||||
condition: field.IsNotNull(),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " IS NOT NULL") },
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Build a simple query to see the condition
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
qry := testTable.Select(field).Where(tt.condition)
|
||||
sql := qry.String()
|
||||
|
||||
if !tt.checkFunc(sql) {
|
||||
t.Errorf("%s condition not properly rendered in SQL: %s", tt.name, sql)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFieldFunctions tests SQL function wrappers
|
||||
func TestFieldFunctions(t *testing.T) {
|
||||
field := Field("users.count")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
result Field
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "Count",
|
||||
result: field.Count(),
|
||||
contains: "COUNT(",
|
||||
},
|
||||
{
|
||||
name: "Sum",
|
||||
result: field.Sum(),
|
||||
contains: "SUM(",
|
||||
},
|
||||
{
|
||||
name: "Avg",
|
||||
result: field.Avg(),
|
||||
contains: "AVG(",
|
||||
},
|
||||
{
|
||||
name: "Max",
|
||||
result: field.Max(),
|
||||
contains: "MAX(",
|
||||
},
|
||||
{
|
||||
name: "Min",
|
||||
result: field.Min(),
|
||||
contains: "Min(",
|
||||
},
|
||||
{
|
||||
name: "Lower",
|
||||
result: field.Lower(),
|
||||
contains: "LOWER(",
|
||||
},
|
||||
{
|
||||
name: "Upper",
|
||||
result: field.Upper(),
|
||||
contains: "UPPER(",
|
||||
},
|
||||
{
|
||||
name: "Trim",
|
||||
result: field.Trim(),
|
||||
contains: "TRIM(",
|
||||
},
|
||||
{
|
||||
name: "Asc",
|
||||
result: field.Asc(),
|
||||
contains: " ASC",
|
||||
},
|
||||
{
|
||||
name: "Desc",
|
||||
result: field.Desc(),
|
||||
contains: " DESC",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resultStr := string(tt.result)
|
||||
if !strings.Contains(resultStr, tt.contains) {
|
||||
t.Errorf("%s should contain %q, got: %s", tt.name, tt.contains, resultStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInsertQueryValidation tests that Insert queries validate fields are set
|
||||
func TestInsertQueryValidation(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
|
||||
t.Run("Exec without fields", func(t *testing.T) {
|
||||
qry := testTable.Insert()
|
||||
err := qry.Exec(context.Background())
|
||||
if err == nil {
|
||||
t.Error("Insert.Exec() should return error when no fields set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||
t.Errorf("Expected error about no fields, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExecTx without fields", func(t *testing.T) {
|
||||
qry := testTable.Insert()
|
||||
// We can't test ExecTx without a real transaction, but we can test the error is defined
|
||||
err := qry.ExecTx(context.Background(), nil)
|
||||
if err == nil {
|
||||
t.Error("Insert.ExecTx() should return error when no fields set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||
t.Errorf("Expected error about no fields, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("First without fields", func(t *testing.T) {
|
||||
qry := testTable.Insert().Returning(idField)
|
||||
var id int
|
||||
err := qry.First(context.Background(), &id)
|
||||
if err == nil {
|
||||
t.Error("Insert.First() should return error when no fields set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||
t.Errorf("Expected error about no fields, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FirstTx without fields", func(t *testing.T) {
|
||||
qry := testTable.Insert().Returning(idField)
|
||||
var id int
|
||||
err := qry.FirstTx(context.Background(), nil, &id)
|
||||
if err == nil {
|
||||
t.Error("Insert.FirstTx() should return error when no fields set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||
t.Errorf("Expected error about no fields, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func TestInsertQuery2(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkInsertQuery-12 2014519 584.0 ns/op 1144 B/op 18 allocs/op
|
||||
// BenchmarkInsertQuery-12 2517757 459.6 ns/op 1032 B/op 13 allocs/op
|
||||
func BenchmarkInsertQuery(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_ = db.User.Insert().
|
||||
@@ -50,7 +50,7 @@ func BenchmarkInsertQuery(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkInsertSetMap-12 1534039 777.1 ns/op 1480 B/op 20 allocs/op
|
||||
// BenchmarkInsertSetMap-12 1812555 644.1 ns/op 1368 B/op 15 allocs/op
|
||||
func BenchmarkInsertSetMap(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_ = db.User.Insert().
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestUpdateQueryValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkUpdateQuery-12 2004985 592.2 ns/op 1176 B/op 20 allocs/op
|
||||
// BenchmarkUpdateQuery-12 2334889 503.6 ns/op 1112 B/op 17 allocs/op
|
||||
func BenchmarkUpdateQuery(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_ = db.User.Update().
|
||||
|
||||
@@ -5,7 +5,7 @@ package pgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -84,21 +84,28 @@ func (q *insertQry) DoNothing() Execute {
|
||||
}
|
||||
|
||||
func (q *insertQry) DoUpdate(fields ...Field) Execute {
|
||||
var sb strings.Builder
|
||||
sb := getSB()
|
||||
defer putSB(sb)
|
||||
|
||||
sb.WriteString("DO UPDATE SET ")
|
||||
for i, f := range fields {
|
||||
col := f.Name()
|
||||
if i == 0 {
|
||||
fmt.Fprintf(&sb, "%s = EXCLUDED.%s", col, col)
|
||||
} else {
|
||||
fmt.Fprintf(&sb, ", %s = EXCLUDED.%s", col, col)
|
||||
if i > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(col)
|
||||
sb.WriteString(" = EXCLUDED.")
|
||||
sb.WriteString(col)
|
||||
}
|
||||
|
||||
q.conflictAction = "DO UPDATE SET " + sb.String()
|
||||
q.conflictAction = sb.String()
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *insertQry) Exec(ctx context.Context) error {
|
||||
if len(q.fields) == 0 {
|
||||
return errors.New("insert query has no fields to insert: call Set() before Exec()")
|
||||
}
|
||||
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -107,6 +114,9 @@ func (q *insertQry) Exec(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||
if len(q.fields) == 0 {
|
||||
return errors.New("insert query has no fields to insert: call Set() before ExecTx()")
|
||||
}
|
||||
_, err := tx.Exec(ctx, q.String(), q.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -116,10 +126,16 @@ func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||
}
|
||||
|
||||
func (q *insertQry) First(ctx context.Context, dest ...any) error {
|
||||
if len(q.fields) == 0 {
|
||||
return errors.New("insert query has no fields to insert: call Set() before First()")
|
||||
}
|
||||
return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||
}
|
||||
|
||||
func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
||||
if len(q.fields) == 0 {
|
||||
return errors.New("insert query has no fields to insert: call Set() before FirstTx()")
|
||||
}
|
||||
return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||
}
|
||||
|
||||
|
||||
@@ -203,7 +203,8 @@ func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, co
|
||||
defer putSB(sb)
|
||||
sb.Grow(len(str) * 2)
|
||||
|
||||
sb.WriteString(str + " AND ")
|
||||
sb.WriteString(str)
|
||||
sb.WriteString(" AND ")
|
||||
|
||||
var argIdx int
|
||||
for i, c := range cond {
|
||||
@@ -263,9 +264,13 @@ func (q *selectQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
||||
|
||||
func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
||||
rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrNoRows
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrNoRows
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
@@ -274,13 +279,21 @@ func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for errors from iteration
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
||||
rows, err := tx.Query(ctx, q.String(), q.args...)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrNoRows
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrNoRows
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -290,6 +303,11 @@ func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for errors from iteration
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user