forked from go/pgm
perf enhancement
This commit is contained in:
50
README.md
50
README.md
@@ -172,6 +172,56 @@ func main() {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Important: Query Builder Lifecycle
|
||||||
|
|
||||||
|
⚠️ **Query builders are single-use and should not be reused:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ WRONG - Don't reuse query builders
|
||||||
|
baseQuery := users.User.Select(users.ID, users.Email)
|
||||||
|
baseQuery.Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||||
|
baseQuery.Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||||
|
// Second query has BOTH WHERE clauses - incorrect behavior!
|
||||||
|
|
||||||
|
// ✅ CORRECT - Create new query each time
|
||||||
|
users.User.Select(users.ID, users.Email).Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||||
|
users.User.Select(users.ID, users.Email).Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why?** Query builders are mutable and accumulate state. Each method call modifies the builder, so reusing the same builder causes conditions to stack up.
|
||||||
|
|
||||||
|
### Thread Safety
|
||||||
|
|
||||||
|
⚠️ **Query builders are NOT thread-safe** and must not be shared across goroutines:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ✅ CORRECT - Each goroutine creates its own query
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
var email string
|
||||||
|
err := users.User.Select(users.Email).
|
||||||
|
Where(users.ID.Eq(id)).
|
||||||
|
First(ctx, &email)
|
||||||
|
// Process result...
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ❌ WRONG - Sharing query builder across goroutines
|
||||||
|
baseQuery := users.User.Select(users.Email)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
var email string
|
||||||
|
baseQuery.Where(users.ID.Eq(id)).First(ctx, &email)
|
||||||
|
// RACE CONDITION! Multiple goroutines modifying shared state
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Thread-Safe Components:**
|
||||||
|
- ✅ Connection Pool - Safe for concurrent use
|
||||||
|
- ✅ Table objects - Safe to share
|
||||||
|
- ❌ Query builders - Create new instance per goroutine
|
||||||
|
|
||||||
## Usage Examples
|
## Usage Examples
|
||||||
|
|
||||||
### SELECT Queries
|
### SELECT Queries
|
||||||
|
|||||||
57
pgm.go
57
pgm.go
@@ -25,13 +25,15 @@ var (
|
|||||||
ErrConnStringMissing = errors.New("connection string is empty")
|
ErrConnStringMissing = errors.New("connection string is empty")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Errors
|
// Common errors returned by pgm operations
|
||||||
var (
|
var (
|
||||||
ErrInitTX = errors.New("failed to init db.tx")
|
ErrInitTX = errors.New("failed to init db.tx")
|
||||||
ErrCommitTX = errors.New("failed to commit db.tx")
|
ErrCommitTX = errors.New("failed to commit db.tx")
|
||||||
ErrNoRows = errors.New("no data found")
|
ErrNoRows = errors.New("no data found")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Config holds the configuration for initializing the connection pool.
|
||||||
|
// All fields except ConnString are optional and will use pgx defaults if not set.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ConnString string
|
ConnString string
|
||||||
MaxConns int32
|
MaxConns int32
|
||||||
@@ -40,7 +42,17 @@ type Config struct {
|
|||||||
MaxConnIdleTime time.Duration
|
MaxConnIdleTime time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitPool will create new pgxpool.Pool and will keep it for its working
|
// InitPool initializes the connection pool with the provided configuration.
|
||||||
|
// It validates the configuration and panics if invalid.
|
||||||
|
// This function should be called once at application startup.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// pgm.InitPool(pgm.Config{
|
||||||
|
// ConnString: "postgres://user:pass@localhost/dbname",
|
||||||
|
// MaxConns: 100,
|
||||||
|
// MinConns: 5,
|
||||||
|
// })
|
||||||
func InitPool(conf Config) {
|
func InitPool(conf Config) {
|
||||||
if conf.ConnString == "" {
|
if conf.ConnString == "" {
|
||||||
panic(ErrConnStringMissing)
|
panic(ErrConnStringMissing)
|
||||||
@@ -88,9 +100,15 @@ func InitPool(conf Config) {
|
|||||||
poolPGX.Store(p)
|
poolPGX.Store(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPool instance
|
// GetPool returns the initialized connection pool instance.
|
||||||
|
// It panics with a descriptive message if InitPool() has not been called.
|
||||||
|
// This is a fail-fast approach to catch programming errors early.
|
||||||
func GetPool() *pgxpool.Pool {
|
func GetPool() *pgxpool.Pool {
|
||||||
return poolPGX.Load()
|
p := poolPGX.Load()
|
||||||
|
if p == nil {
|
||||||
|
panic("pgm: connection pool not initialized, call InitPool() first")
|
||||||
|
}
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClosePool closes the connection pool gracefully.
|
// ClosePool closes the connection pool gracefully.
|
||||||
@@ -102,7 +120,21 @@ func ClosePool() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeginTx begins a pgx poll transaction
|
// BeginTx begins a new database transaction from the connection pool.
|
||||||
|
// Returns an error if the transaction cannot be started.
|
||||||
|
// Remember to commit or rollback the transaction when done.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// tx, err := pgm.BeginTx(ctx)
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// defer tx.Rollback(ctx) // rollback on error
|
||||||
|
//
|
||||||
|
// // ... do work ...
|
||||||
|
//
|
||||||
|
// return tx.Commit(ctx)
|
||||||
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
||||||
tx, err := poolPGX.Load().Begin(ctx)
|
tx, err := poolPGX.Load().Begin(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -113,32 +145,43 @@ func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
|||||||
return tx, nil
|
return tx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsNotFound error check
|
// IsNotFound checks if an error is a "no rows" error from pgx.
|
||||||
|
// Returns true if the error indicates no rows were found in a query result.
|
||||||
func IsNotFound(err error) bool {
|
func IsNotFound(err error) bool {
|
||||||
return errors.Is(err, pgx.ErrNoRows)
|
return errors.Is(err, pgx.ErrNoRows)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PgTime as in UTC
|
// PgTime converts a Go time.Time to PostgreSQL timestamptz type.
|
||||||
|
// The time is stored as-is (preserves timezone information).
|
||||||
func PgTime(t time.Time) pgtype.Timestamptz {
|
func PgTime(t time.Time) pgtype.Timestamptz {
|
||||||
return pgtype.Timestamptz{Time: t, Valid: true}
|
return pgtype.Timestamptz{Time: t, Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PgTimeNow returns the current time as PostgreSQL timestamptz type.
|
||||||
func PgTimeNow() pgtype.Timestamptz {
|
func PgTimeNow() pgtype.Timestamptz {
|
||||||
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TsAndQuery converts a text search query to use AND operator between terms.
|
||||||
|
// Example: "hello world" becomes "hello & world"
|
||||||
func TsAndQuery(q string) string {
|
func TsAndQuery(q string) string {
|
||||||
return strings.Join(strings.Fields(q), " & ")
|
return strings.Join(strings.Fields(q), " & ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TsPrefixAndQuery converts a text search query to use AND operator with prefix matching.
|
||||||
|
// Example: "hello world" becomes "hello:* & world:*"
|
||||||
func TsPrefixAndQuery(q string) string {
|
func TsPrefixAndQuery(q string) string {
|
||||||
return strings.Join(fieldsWithSufix(q, ":*"), " & ")
|
return strings.Join(fieldsWithSufix(q, ":*"), " & ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TsOrQuery converts a text search query to use OR operator between terms.
|
||||||
|
// Example: "hello world" becomes "hello | world"
|
||||||
func TsOrQuery(q string) string {
|
func TsOrQuery(q string) string {
|
||||||
return strings.Join(strings.Fields(q), " | ")
|
return strings.Join(strings.Fields(q), " | ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TsPrefixOrQuery converts a text search query to use OR operator with prefix matching.
|
||||||
|
// Example: "hello world" becomes "hello:* | world:*"
|
||||||
func TsPrefixOrQuery(q string) string {
|
func TsPrefixOrQuery(q string) string {
|
||||||
return strings.Join(fieldsWithSufix(q, ":*"), " | ")
|
return strings.Join(fieldsWithSufix(q, ":*"), " | ")
|
||||||
}
|
}
|
||||||
|
|||||||
19
pgm_field.go
19
pgm_field.go
@@ -45,7 +45,12 @@ func escapeSQLString(s string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f Field) Name() string {
|
func (f Field) Name() string {
|
||||||
return strings.Split(string(f), ".")[1]
|
s := string(f)
|
||||||
|
idx := strings.LastIndexByte(s, '.')
|
||||||
|
if idx == -1 {
|
||||||
|
return s // Return as-is if no dot
|
||||||
|
}
|
||||||
|
return s[idx+1:] // Return part after last dot
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f Field) String() string {
|
func (f Field) String() string {
|
||||||
@@ -66,19 +71,19 @@ func ConcatWs(sep string, fields ...Field) Field {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// StringAgg creates a STRING_AGG SQL function.
|
// StringAgg creates a STRING_AGG SQL function.
|
||||||
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
|
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||||
func StringAgg(exp, sep string) Field {
|
func StringAgg(field Field, sep string) Field {
|
||||||
escapedSep := escapeSQLString(sep)
|
escapedSep := escapeSQLString(sep)
|
||||||
return Field("string_agg(" + exp + ",'" + escapedSep + "')")
|
return Field("string_agg(" + field.String() + ",'" + escapedSep + "')")
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
|
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
|
||||||
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
|
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||||
func StringAggCast(exp, sep string) Field {
|
func StringAggCast(field Field, sep string) Field {
|
||||||
escapedSep := escapeSQLString(sep)
|
escapedSep := escapeSQLString(sep)
|
||||||
return Field("string_agg(cast(" + exp + " as varchar),'" + escapedSep + "')")
|
return Field("string_agg(cast(" + field.String() + " as varchar),'" + escapedSep + "')")
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringEscape will wrap field with:
|
// StringEscape will wrap field with:
|
||||||
|
|||||||
@@ -102,29 +102,29 @@ func TestConcatWsSQLInjectionPrevention(t *testing.T) {
|
|||||||
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
exp string
|
field Field
|
||||||
sep string
|
sep string
|
||||||
contains string
|
contains string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "safe parameters",
|
name: "safe parameters",
|
||||||
exp: "column_name",
|
field: Field("column_name"),
|
||||||
sep: ", ",
|
sep: ", ",
|
||||||
contains: "string_agg(column_name,', ')",
|
contains: "string_agg(column_name,', ')",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "escaped quotes in separator",
|
name: "escaped quotes in separator",
|
||||||
sep: "'; DROP TABLE users; --",
|
sep: "'; DROP TABLE users; --",
|
||||||
exp: "col",
|
field: Field("col"),
|
||||||
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := StringAgg(tt.exp, tt.sep)
|
result := StringAgg(tt.field, tt.sep)
|
||||||
if !strings.Contains(string(result), tt.contains) {
|
if !strings.Contains(string(result), tt.contains) {
|
||||||
t.Errorf("StringAgg(%q, %q) = %q, should contain %q", tt.exp, tt.sep, result, tt.contains)
|
t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -132,7 +132,7 @@ func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
|||||||
|
|
||||||
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
||||||
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
||||||
result := StringAggCast("column", "'; DROP TABLE")
|
result := StringAggCast(Field("column"), "'; DROP TABLE")
|
||||||
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
||||||
if string(result) != expected {
|
if string(result) != expected {
|
||||||
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
||||||
@@ -336,3 +336,47 @@ func TestSQLInjectionAttackVectors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestFieldName tests Field.Name() method
|
||||||
|
func TestFieldName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
field Field
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "field with table prefix",
|
||||||
|
field: Field("users.email"),
|
||||||
|
expected: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field with schema and table prefix",
|
||||||
|
field: Field("public.users.email"),
|
||||||
|
expected: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field without dot",
|
||||||
|
field: Field("email"),
|
||||||
|
expected: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field with multiple dots",
|
||||||
|
field: Field("schema.table.column"),
|
||||||
|
expected: "column",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "aggregate function",
|
||||||
|
field: Field("COUNT(users.id)"),
|
||||||
|
expected: "id)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.field.Name()
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
706
pgm_test.go
Normal file
706
pgm_test.go
Normal file
@@ -0,0 +1,706 @@
|
|||||||
|
// Patial Tech.
|
||||||
|
// Author, Ankit Patial
|
||||||
|
|
||||||
|
package pgm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGetPoolNotInitialized verifies that GetPool panics when pool is not initialized
|
||||||
|
func TestGetPoolNotInitialized(t *testing.T) {
|
||||||
|
// Save current pool state
|
||||||
|
currentPool := poolPGX.Load()
|
||||||
|
|
||||||
|
// Clear the pool to simulate uninitialized state
|
||||||
|
poolPGX.Store(nil)
|
||||||
|
|
||||||
|
// Restore pool after test
|
||||||
|
defer func() {
|
||||||
|
poolPGX.Store(currentPool)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Verify panic occurs
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Error("GetPool should panic when pool not initialized")
|
||||||
|
} else {
|
||||||
|
// Check panic message
|
||||||
|
msg, ok := r.(string)
|
||||||
|
if !ok || !strings.Contains(msg, "not initialized") {
|
||||||
|
t.Errorf("Expected panic message about initialization, got: %v", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
GetPool() // Should panic
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConfigValidation tests that InitPool validates configuration
|
||||||
|
func TestConfigValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantPanic bool
|
||||||
|
panicMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty connection string",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "",
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "connection string is empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MinConns greater than MaxConns",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "postgres://localhost/test",
|
||||||
|
MaxConns: 5,
|
||||||
|
MinConns: 10,
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "MinConns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative MaxConns",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "postgres://localhost/test",
|
||||||
|
MaxConns: -1,
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "negative",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative MinConns",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "postgres://localhost/test",
|
||||||
|
MinConns: -1,
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "negative",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if tt.wantPanic && r == nil {
|
||||||
|
t.Errorf("InitPool should panic for %s", tt.name)
|
||||||
|
}
|
||||||
|
if !tt.wantPanic && r != nil {
|
||||||
|
t.Errorf("InitPool should not panic for %s, got: %v", tt.name, r)
|
||||||
|
}
|
||||||
|
if tt.wantPanic && r != nil {
|
||||||
|
msg := ""
|
||||||
|
switch v := r.(type) {
|
||||||
|
case string:
|
||||||
|
msg = v
|
||||||
|
case error:
|
||||||
|
msg = v.Error()
|
||||||
|
}
|
||||||
|
if !strings.Contains(msg, tt.panicMsg) {
|
||||||
|
t.Errorf("Expected panic message to contain %q, got: %q", tt.panicMsg, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
InitPool(tt.config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClosePoolIdempotent verifies ClosePool can be called multiple times safely
|
||||||
|
func TestClosePoolIdempotent(t *testing.T) {
|
||||||
|
// Save current pool
|
||||||
|
currentPool := poolPGX.Load()
|
||||||
|
defer func() {
|
||||||
|
poolPGX.Store(currentPool)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set to nil
|
||||||
|
poolPGX.Store(nil)
|
||||||
|
|
||||||
|
// Should not panic when called on nil pool
|
||||||
|
ClosePool()
|
||||||
|
ClosePool()
|
||||||
|
ClosePool()
|
||||||
|
|
||||||
|
// Should work fine - no panic expected
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsNotFound tests the error checking utility
|
||||||
|
func TestIsNotFound(t *testing.T) {
|
||||||
|
// Test with pgx.ErrNoRows
|
||||||
|
if !IsNotFound(pgx.ErrNoRows) {
|
||||||
|
t.Error("IsNotFound should return true for pgx.ErrNoRows")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with other errors
|
||||||
|
otherErr := errors.New("some other error")
|
||||||
|
if IsNotFound(otherErr) {
|
||||||
|
t.Error("IsNotFound should return false for non-ErrNoRows errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with nil
|
||||||
|
if IsNotFound(nil) {
|
||||||
|
t.Error("IsNotFound should return false for nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPgTime tests PostgreSQL timestamp conversion
|
||||||
|
func TestPgTime(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
pgTime := PgTime(now)
|
||||||
|
|
||||||
|
if !pgTime.Valid {
|
||||||
|
t.Error("PgTime should return valid timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pgTime.Time.Equal(now) {
|
||||||
|
t.Errorf("PgTime time mismatch: got %v, want %v", pgTime.Time, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPgTimeNow tests current time conversion
|
||||||
|
func TestPgTimeNow(t *testing.T) {
|
||||||
|
before := time.Now()
|
||||||
|
pgTime := PgTimeNow()
|
||||||
|
after := time.Now()
|
||||||
|
|
||||||
|
if !pgTime.Valid {
|
||||||
|
t.Error("PgTimeNow should return valid timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that time is between before and after
|
||||||
|
if pgTime.Time.Before(before) || pgTime.Time.After(after) {
|
||||||
|
t.Error("PgTimeNow should return current time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTsQueryFunctions tests full-text search query builders
|
||||||
|
func TestTsQueryFunctions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func(string) string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "TsAndQuery",
|
||||||
|
fn: TsAndQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello & world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsPrefixAndQuery",
|
||||||
|
fn: TsPrefixAndQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello:* & world:*",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsOrQuery",
|
||||||
|
fn: TsOrQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello | world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsPrefixOrQuery",
|
||||||
|
fn: TsPrefixOrQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello:* | world:*",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsAndQuery with multiple words",
|
||||||
|
fn: TsAndQuery,
|
||||||
|
input: "one two three",
|
||||||
|
expected: "one & two & three",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsOrQuery with multiple words",
|
||||||
|
fn: TsOrQuery,
|
||||||
|
input: "one two three",
|
||||||
|
expected: "one | two | three",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsAndQuery with extra spaces",
|
||||||
|
fn: TsAndQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello & world",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.fn(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("%s(%q) = %q, want %q", tt.name, tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldsWithSuffix tests the internal suffix function
|
||||||
|
func TestFieldsWithSuffix(t *testing.T) {
|
||||||
|
result := fieldsWithSufix("hello world", ":*")
|
||||||
|
expected := []string{"hello:*", "world:*"}
|
||||||
|
|
||||||
|
if len(result) != len(expected) {
|
||||||
|
t.Fatalf("fieldsWithSufix length mismatch: got %d, want %d", len(result), len(expected))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, v := range result {
|
||||||
|
if v != expected[i] {
|
||||||
|
t.Errorf("fieldsWithSufix[%d] = %q, want %q", i, v, expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBeginTxErrorWrapping tests that BeginTx preserves error context
|
||||||
|
func TestBeginTxErrorWrapping(t *testing.T) {
|
||||||
|
// This test verifies basic behavior without requiring a database connection
|
||||||
|
// Actual error wrapping would require a failing database connection
|
||||||
|
|
||||||
|
// Skip if no pool initialized (unit test environment)
|
||||||
|
if poolPGX.Load() == nil {
|
||||||
|
t.Skip("Skipping BeginTx test - requires initialized pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Just verify the function accepts a context
|
||||||
|
// In a real scenario with a cancelled context, it would return an error
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// We can't fully test without a real DB connection, so just document behavior
|
||||||
|
_, err := BeginTx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Error is expected in test environment without valid DB
|
||||||
|
t.Logf("BeginTx returned error (expected in test env): %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestContextCancellation tests that cancelled context is handled
|
||||||
|
func TestContextCancellation(t *testing.T) {
|
||||||
|
// Create an already-cancelled context
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
// Note: This test would require an actual database connection to verify
|
||||||
|
// that the context cancellation is properly propagated. We're testing
|
||||||
|
// that the context is accepted by the function signature.
|
||||||
|
|
||||||
|
// The actual behavior would be tested in integration tests
|
||||||
|
// For now, we just verify the function doesn't panic with cancelled context
|
||||||
|
|
||||||
|
// Skip if no pool initialized (unit test environment)
|
||||||
|
if poolPGX.Load() == nil {
|
||||||
|
t.Skip("Skipping context cancellation test - requires initialized pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// This would return an error about cancelled context in real scenario
|
||||||
|
_, err := BeginTx(ctx)
|
||||||
|
if err == nil {
|
||||||
|
t.Log("BeginTx with cancelled context - error expected in real scenario")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestErrorTypes verifies exported error variables
|
||||||
|
func TestErrorTypes(t *testing.T) {
|
||||||
|
if ErrConnStringMissing == nil {
|
||||||
|
t.Error("ErrConnStringMissing should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ErrInitTX == nil {
|
||||||
|
t.Error("ErrInitTX should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ErrCommitTX == nil {
|
||||||
|
t.Error("ErrCommitTX should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ErrNoRows == nil {
|
||||||
|
t.Error("ErrNoRows should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check error messages are descriptive
|
||||||
|
if !strings.Contains(ErrConnStringMissing.Error(), "connection string") {
|
||||||
|
t.Error("ErrConnStringMissing should mention connection string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStringPoolGetPut tests the string builder pool
|
||||||
|
func TestStringPoolGetPut(t *testing.T) {
|
||||||
|
// Get a string builder from pool
|
||||||
|
sb1 := getSB()
|
||||||
|
if sb1 == nil {
|
||||||
|
t.Fatal("getSB should return non-nil string builder")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use it
|
||||||
|
sb1.WriteString("test")
|
||||||
|
if sb1.String() != "test" {
|
||||||
|
t.Error("String builder should work normally")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put it back
|
||||||
|
putSB(sb1)
|
||||||
|
|
||||||
|
// Get another one - should be reset
|
||||||
|
sb2 := getSB()
|
||||||
|
if sb2 == nil {
|
||||||
|
t.Fatal("getSB should return non-nil string builder")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be empty after reset
|
||||||
|
if sb2.Len() != 0 {
|
||||||
|
t.Error("String builder from pool should be reset")
|
||||||
|
}
|
||||||
|
|
||||||
|
putSB(sb2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConditionActionTypes tests condition action enum
|
||||||
|
func TestConditionActionTypes(t *testing.T) {
|
||||||
|
// Verify enum values are distinct
|
||||||
|
actions := []CondAction{
|
||||||
|
CondActionNothing,
|
||||||
|
CondActionNeedToClose,
|
||||||
|
CondActionSubQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[CondAction]bool)
|
||||||
|
for _, action := range actions {
|
||||||
|
if seen[action] {
|
||||||
|
t.Errorf("Duplicate CondAction value: %v", action)
|
||||||
|
}
|
||||||
|
seen[action] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestQueryBuilderReuse documents query builder reuse behavior
|
||||||
|
// This test verifies that query builders accumulate state and should NOT be reused
|
||||||
|
func TestQueryBuilderReuse(t *testing.T) {
|
||||||
|
t.Skip("Query builders are not designed for reuse - this test documents the limitation")
|
||||||
|
|
||||||
|
// This test would require actual database tables to demonstrate the issue
|
||||||
|
// The behavior is:
|
||||||
|
// 1. Creating a base query builder
|
||||||
|
// 2. Adding a WHERE clause
|
||||||
|
// 3. Reusing the same builder with another WHERE clause
|
||||||
|
// 4. The second query would have BOTH WHERE clauses
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// baseQuery := users.User.Select(user.ID, user.Email)
|
||||||
|
// baseQuery.Where(user.ID.Eq(1)) // First query
|
||||||
|
// baseQuery.Where(user.Status.Eq(2)) // Second query has BOTH conditions!
|
||||||
|
//
|
||||||
|
// This is by design - query builders are mutable and single-use.
|
||||||
|
// Each query should create a new builder instance.
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestQueryBuilderThreadSafety documents that query builders are not thread-safe
|
||||||
|
func TestQueryBuilderThreadSafety(t *testing.T) {
|
||||||
|
t.Skip("Query builders are not thread-safe by design - this test documents the limitation")
|
||||||
|
|
||||||
|
// Query builders accumulate state in their internal fields (where, args, etc.)
|
||||||
|
// and do not use any synchronization primitives.
|
||||||
|
//
|
||||||
|
// Sharing a query builder across goroutines will cause race conditions.
|
||||||
|
//
|
||||||
|
// CORRECT usage: Create a new query builder in each goroutine
|
||||||
|
// INCORRECT usage: Share a query builder variable across goroutines
|
||||||
|
//
|
||||||
|
// The connection pool itself IS thread-safe and can be used concurrently.
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSelectQueryBuilderBasics tests basic query building
|
||||||
|
func TestSelectQueryBuilderBasics(t *testing.T) {
|
||||||
|
// Create a test table
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
emailField := Field("users.email")
|
||||||
|
|
||||||
|
// Test basic select
|
||||||
|
qry := testTable.Select(idField, emailField)
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "SELECT") {
|
||||||
|
t.Error("Query should contain SELECT")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "users.id") {
|
||||||
|
t.Error("Query should contain users.id field")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "users.email") {
|
||||||
|
t.Error("Query should contain users.email field")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "FROM users") {
|
||||||
|
t.Error("Query should contain FROM users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWhereConditionAccumulation tests that WHERE conditions accumulate
|
||||||
|
func TestWhereConditionAccumulation(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
statusField := Field("users.status")
|
||||||
|
|
||||||
|
// Create a query and add multiple WHERE conditions
|
||||||
|
qry := testTable.Select(idField).
|
||||||
|
Where(idField.Eq(1)).
|
||||||
|
Where(statusField.Eq("active"))
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
// Both conditions should be present
|
||||||
|
if !strings.Contains(sql, "WHERE") {
|
||||||
|
t.Error("Query should contain WHERE clause")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have multiple conditions (this demonstrates accumulation)
|
||||||
|
if strings.Count(sql, "$") < 2 {
|
||||||
|
t.Error("Query should have multiple parameterized values")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInsertQueryBuilder tests insert query building
|
||||||
|
func TestInsertQueryBuilder(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
|
||||||
|
qry := testTable.Insert()
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "INSERT INTO users") {
|
||||||
|
t.Error("Query should contain INSERT INTO users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpdateQueryBuilder tests update query building
|
||||||
|
func TestUpdateQueryBuilder(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
|
||||||
|
qry := testTable.Update().
|
||||||
|
Where(idField.Eq(1))
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "UPDATE users") {
|
||||||
|
t.Error("Query should contain UPDATE users")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "WHERE") {
|
||||||
|
t.Error("Query should contain WHERE clause")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteQueryBuilder tests delete query building
|
||||||
|
func TestDeleteQueryBuilder(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
|
||||||
|
qry := testTable.Delete().Where(idField.Eq(1))
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "DELETE FROM users") {
|
||||||
|
t.Error("Query should contain DELETE FROM users")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "WHERE") {
|
||||||
|
t.Error("Query should contain WHERE clause")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldConditioners tests various field condition methods
|
||||||
|
func TestFieldConditioners(t *testing.T) {
|
||||||
|
field := Field("users.age")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
condition Conditioner
|
||||||
|
checkFunc func(string) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Eq",
|
||||||
|
condition: field.Eq(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " = $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NotEq",
|
||||||
|
condition: field.NotEq(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " != $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gt",
|
||||||
|
condition: field.Gt(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " > $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Lt",
|
||||||
|
condition: field.Lt(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " < $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gte",
|
||||||
|
condition: field.Gte(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " >= $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Lte",
|
||||||
|
condition: field.Lte(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " <= $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IsNull",
|
||||||
|
condition: field.IsNull(),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " IS NULL") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IsNotNull",
|
||||||
|
condition: field.IsNotNull(),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " IS NOT NULL") },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Build a simple query to see the condition
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
qry := testTable.Select(field).Where(tt.condition)
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !tt.checkFunc(sql) {
|
||||||
|
t.Errorf("%s condition not properly rendered in SQL: %s", tt.name, sql)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldFunctions tests SQL function wrappers
|
||||||
|
func TestFieldFunctions(t *testing.T) {
|
||||||
|
field := Field("users.count")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result Field
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Count",
|
||||||
|
result: field.Count(),
|
||||||
|
contains: "COUNT(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Sum",
|
||||||
|
result: field.Sum(),
|
||||||
|
contains: "SUM(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Avg",
|
||||||
|
result: field.Avg(),
|
||||||
|
contains: "AVG(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max",
|
||||||
|
result: field.Max(),
|
||||||
|
contains: "MAX(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Min",
|
||||||
|
result: field.Min(),
|
||||||
|
contains: "Min(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Lower",
|
||||||
|
result: field.Lower(),
|
||||||
|
contains: "LOWER(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Upper",
|
||||||
|
result: field.Upper(),
|
||||||
|
contains: "UPPER(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Trim",
|
||||||
|
result: field.Trim(),
|
||||||
|
contains: "TRIM(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Asc",
|
||||||
|
result: field.Asc(),
|
||||||
|
contains: " ASC",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Desc",
|
||||||
|
result: field.Desc(),
|
||||||
|
contains: " DESC",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resultStr := string(tt.result)
|
||||||
|
if !strings.Contains(resultStr, tt.contains) {
|
||||||
|
t.Errorf("%s should contain %q, got: %s", tt.name, tt.contains, resultStr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInsertQueryValidation tests that Insert queries validate fields are set
|
||||||
|
func TestInsertQueryValidation(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
|
||||||
|
t.Run("Exec without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert()
|
||||||
|
err := qry.Exec(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.Exec() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExecTx without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert()
|
||||||
|
// We can't test ExecTx without a real transaction, but we can test the error is defined
|
||||||
|
err := qry.ExecTx(context.Background(), nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.ExecTx() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("First without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert().Returning(idField)
|
||||||
|
var id int
|
||||||
|
err := qry.First(context.Background(), &id)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.First() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("FirstTx without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert().Returning(idField)
|
||||||
|
var id int
|
||||||
|
err := qry.FirstTx(context.Background(), nil, &id)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.FirstTx() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -36,7 +36,7 @@ func TestInsertQuery2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkInsertQuery-12 2014519 584.0 ns/op 1144 B/op 18 allocs/op
|
// BenchmarkInsertQuery-12 2517757 459.6 ns/op 1032 B/op 13 allocs/op
|
||||||
func BenchmarkInsertQuery(b *testing.B) {
|
func BenchmarkInsertQuery(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = db.User.Insert().
|
_ = db.User.Insert().
|
||||||
@@ -50,7 +50,7 @@ func BenchmarkInsertQuery(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkInsertSetMap-12 1534039 777.1 ns/op 1480 B/op 20 allocs/op
|
// BenchmarkInsertSetMap-12 1812555 644.1 ns/op 1368 B/op 15 allocs/op
|
||||||
func BenchmarkInsertSetMap(b *testing.B) {
|
func BenchmarkInsertSetMap(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = db.User.Insert().
|
_ = db.User.Insert().
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func TestUpdateQueryValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkUpdateQuery-12 2004985 592.2 ns/op 1176 B/op 20 allocs/op
|
// BenchmarkUpdateQuery-12 2334889 503.6 ns/op 1112 B/op 17 allocs/op
|
||||||
func BenchmarkUpdateQuery(b *testing.B) {
|
func BenchmarkUpdateQuery(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = db.User.Update().
|
_ = db.User.Update().
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ package pgm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -84,21 +84,28 @@ func (q *insertQry) DoNothing() Execute {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) DoUpdate(fields ...Field) Execute {
|
func (q *insertQry) DoUpdate(fields ...Field) Execute {
|
||||||
var sb strings.Builder
|
sb := getSB()
|
||||||
|
defer putSB(sb)
|
||||||
|
|
||||||
|
sb.WriteString("DO UPDATE SET ")
|
||||||
for i, f := range fields {
|
for i, f := range fields {
|
||||||
col := f.Name()
|
col := f.Name()
|
||||||
if i == 0 {
|
if i > 0 {
|
||||||
fmt.Fprintf(&sb, "%s = EXCLUDED.%s", col, col)
|
sb.WriteString(", ")
|
||||||
} else {
|
|
||||||
fmt.Fprintf(&sb, ", %s = EXCLUDED.%s", col, col)
|
|
||||||
}
|
}
|
||||||
|
sb.WriteString(col)
|
||||||
|
sb.WriteString(" = EXCLUDED.")
|
||||||
|
sb.WriteString(col)
|
||||||
}
|
}
|
||||||
|
|
||||||
q.conflictAction = "DO UPDATE SET " + sb.String()
|
q.conflictAction = sb.String()
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) Exec(ctx context.Context) error {
|
func (q *insertQry) Exec(ctx context.Context) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before Exec()")
|
||||||
|
}
|
||||||
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -107,6 +114,9 @@ func (q *insertQry) Exec(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before ExecTx()")
|
||||||
|
}
|
||||||
_, err := tx.Exec(ctx, q.String(), q.args...)
|
_, err := tx.Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -116,10 +126,16 @@ func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) First(ctx context.Context, dest ...any) error {
|
func (q *insertQry) First(ctx context.Context, dest ...any) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before First()")
|
||||||
|
}
|
||||||
return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before FirstTx()")
|
||||||
|
}
|
||||||
return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -203,7 +203,8 @@ func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, co
|
|||||||
defer putSB(sb)
|
defer putSB(sb)
|
||||||
sb.Grow(len(str) * 2)
|
sb.Grow(len(str) * 2)
|
||||||
|
|
||||||
sb.WriteString(str + " AND ")
|
sb.WriteString(str)
|
||||||
|
sb.WriteString(" AND ")
|
||||||
|
|
||||||
var argIdx int
|
var argIdx int
|
||||||
for i, c := range cond {
|
for i, c := range cond {
|
||||||
@@ -263,9 +264,13 @@ func (q *selectQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
|||||||
|
|
||||||
func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
||||||
rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...)
|
rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...)
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if err != nil {
|
||||||
return ErrNoRows
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return ErrNoRows
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@@ -274,13 +279,21 @@ func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for errors from iteration
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
||||||
rows, err := tx.Query(ctx, q.String(), q.args...)
|
rows, err := tx.Query(ctx, q.String(), q.args...)
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if err != nil {
|
||||||
return ErrNoRows
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return ErrNoRows
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
@@ -290,6 +303,11 @@ func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for errors from iteration
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user