forked from go/pgm
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c2cf7ff088 | |||
| cc7e6b7b3f | |||
| 48f1d1952e | |||
| 1d9d9d9308 | |||
| 29cddb6389 | |||
| 551e2123bc | |||
| a2b984c342 | |||
| a795c0e8d6 | |||
| bb6a45732f | |||
| 2551e07b3e | |||
| 9837fb1e37 | |||
| 12d6fface6 | |||
| 325103e8ef | |||
| 6f5748d3d3 | |||
| b25f9367ed | |||
| 8750f3ad95 | |||
| ad1faf2056 | |||
| 525c64e678 | |||
| 5f0fdadb8b | |||
| 68263895f7 | |||
| ee6cb445ab | |||
| d07c25fe01 | |||
| 096480a3eb | |||
| 6c14441591 | |||
| d95eea6636 | |||
| 63b71692b5 | |||
| 36e4145365 | |||
| 2ec328059f | |||
| f700f3e891 | |||
| f5350292fc |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -25,4 +25,4 @@ go.work.sum
|
|||||||
# env file
|
# env file
|
||||||
.env
|
.env
|
||||||
|
|
||||||
example/local_*
|
playground/local_*
|
||||||
|
|||||||
25
Makefile
25
Makefile
@@ -1,4 +1,25 @@
|
|||||||
|
.PHONY: run bench-select test build install
|
||||||
|
|
||||||
|
# Version can be set via: make build VERSION=v1.2.3
|
||||||
|
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||||
|
|
||||||
run:
|
run:
|
||||||
go run ./cmd -o ./example/db ./example/schema.sql
|
go run ./cmd -o ./playground/db ./playground/schema.sql
|
||||||
|
|
||||||
bench-select:
|
bench-select:
|
||||||
go test ./example -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
go test ./playground -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test ./playground
|
||||||
|
|
||||||
|
# Build with version information
|
||||||
|
build:
|
||||||
|
go build -ldflags "-X main.version=$(VERSION)" -o pgm ./cmd
|
||||||
|
|
||||||
|
# Install to GOPATH/bin with version
|
||||||
|
install:
|
||||||
|
go install -ldflags "-X main.version=$(VERSION)" ./cmd
|
||||||
|
|
||||||
|
# Show current version
|
||||||
|
version:
|
||||||
|
@echo "Version: $(VERSION)"
|
||||||
|
|||||||
795
README.md
795
README.md
@@ -1,5 +1,794 @@
|
|||||||
# pgm (Postgres Mapper)
|
# pgm - PostgreSQL Query Mapper
|
||||||
|
|
||||||
Simple query builder to work with Go:PG apps.
|
[](https://pkg.go.dev/code.patial.tech/go/pgm)
|
||||||
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
Will work along side with [dbmate](https://github.com/amacneil/dbmate), will consume schema.sql file created by dbmate
|
A lightweight, type-safe PostgreSQL query builder for Go, built on top of [jackc/pgx](https://github.com/jackc/pgx). **pgm** generates Go code from your SQL schema, enabling you to write SQL queries with compile-time safety and autocompletion support.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Type-safe queries** - Column and table names are validated at compile time
|
||||||
|
- **Zero reflection** - Fast performance with no runtime reflection overhead
|
||||||
|
- **SQL schema-based** - Generate Go code directly from your SQL schema files
|
||||||
|
- **Fluent API** - Intuitive query builder with method chaining
|
||||||
|
- **Transaction support** - First-class support for pgx transactions
|
||||||
|
- **Full-text search** - Built-in PostgreSQL full-text search helpers
|
||||||
|
- **Connection pooling** - Leverages pgx connection pool for optimal performance
|
||||||
|
- **Minimal code generation** - Only generates what you need, no bloat
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Why pgm?](#why-pgm)
|
||||||
|
- [Installation](#installation)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [Usage Examples](#usage-examples)
|
||||||
|
- [SELECT Queries](#select-queries)
|
||||||
|
- [INSERT Queries](#insert-queries)
|
||||||
|
- [UPDATE Queries](#update-queries)
|
||||||
|
- [DELETE Queries](#delete-queries)
|
||||||
|
- [Joins](#joins)
|
||||||
|
- [Transactions](#transactions)
|
||||||
|
- [Full-Text Search](#full-text-search)
|
||||||
|
- [CLI Tool](#cli-tool)
|
||||||
|
- [API Documentation](#api-documentation)
|
||||||
|
- [Contributing](#contributing)
|
||||||
|
- [License](#license)
|
||||||
|
|
||||||
|
## Why pgm?
|
||||||
|
|
||||||
|
### The Problem with Existing ORMs
|
||||||
|
|
||||||
|
While Go has excellent ORMs like [ent](https://github.com/ent/ent) and [sqlc](https://github.com/sqlc-dev/sqlc), they come with tradeoffs:
|
||||||
|
|
||||||
|
**ent** - Feature-rich but heavy:
|
||||||
|
- Generates extensive code for features you may never use
|
||||||
|
- Significantly increases binary size
|
||||||
|
- Complex schema definition in Go instead of SQL
|
||||||
|
- Auto-migrations can obscure actual database schema
|
||||||
|
|
||||||
|
**sqlc** - Great tool, but:
|
||||||
|
- Creates separate database models, forcing model mapping
|
||||||
|
- Query results require their own generated types
|
||||||
|
- Less flexibility in dynamic query building
|
||||||
|
|
||||||
|
### The pgm Approach
|
||||||
|
|
||||||
|
**pgm** takes a hybrid approach:
|
||||||
|
|
||||||
|
✅ **Schema as SQL** - Define your database schema in pure SQL, where it belongs
|
||||||
|
✅ **Minimal generation** - Only generates table and column definitions
|
||||||
|
✅ **Your models** - Use your own application models, no forced abstractions
|
||||||
|
✅ **Type safety** - Catch schema changes at compile time
|
||||||
|
✅ **SQL power** - Full control over your queries with a fluent API
|
||||||
|
✅ **Migration-friendly** - Use mature tools like [dbmate](https://github.com/amacneil/dbmate) for migrations
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get code.patial.tech/go/pgm
|
||||||
|
```
|
||||||
|
|
||||||
|
Install the CLI tool for schema code generation:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go install code.patial.tech/go/pgm/cmd@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
### Building from Source
|
||||||
|
|
||||||
|
Build with automatic version detection (uses git tags):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build with version from git tags
|
||||||
|
make build
|
||||||
|
|
||||||
|
# Build with specific version
|
||||||
|
make build VERSION=v1.2.3
|
||||||
|
|
||||||
|
# Install to GOPATH/bin
|
||||||
|
make install
|
||||||
|
|
||||||
|
# Or build manually with version
|
||||||
|
go build -ldflags "-X main.version=v1.2.3" -o pgm ./cmd
|
||||||
|
```
|
||||||
|
|
||||||
|
Check the version:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pgm -version
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Create Your Schema
|
||||||
|
|
||||||
|
Create a SQL schema file `schema.sql` or use the one created by [dbmate](https://github.com/amacneil/dbmate):
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE TABLE users (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
email VARCHAR(255) UNIQUE NOT NULL,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE posts (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id UUID NOT NULL REFERENCES users(id),
|
||||||
|
title VARCHAR(500) NOT NULL,
|
||||||
|
content TEXT,
|
||||||
|
published BOOLEAN DEFAULT false,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Generate Go Code
|
||||||
|
|
||||||
|
Run the pgm CLI tool:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pgm -o ./db ./schema.sql
|
||||||
|
```
|
||||||
|
|
||||||
|
This generates Go files for each table in `./db/`:
|
||||||
|
- `db/users/users.go` - Table and column definitions for users
|
||||||
|
- `db/posts/posts.go` - Table and column definitions for posts
|
||||||
|
|
||||||
|
### 3. Use in Your Code
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"code.patial.tech/go/pgm"
|
||||||
|
"yourapp/db/users"
|
||||||
|
"yourapp/db/posts"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Initialize connection pool
|
||||||
|
pgm.InitPool(pgm.Config{
|
||||||
|
ConnString: "postgres://user:pass@localhost:5432/dbname",
|
||||||
|
MaxConns: 25,
|
||||||
|
MinConns: 5,
|
||||||
|
})
|
||||||
|
defer pgm.ClosePool() // Ensure graceful shutdown
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Query a user
|
||||||
|
var email string
|
||||||
|
err := users.User.Select(users.Email).
|
||||||
|
Where(users.ID.Eq("some-uuid")).
|
||||||
|
First(ctx, &email)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("User email: %s", email)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Important: Query Builder Lifecycle
|
||||||
|
|
||||||
|
### ✅ Conditional Building (CORRECT)
|
||||||
|
|
||||||
|
Query builders are **mutable by design** to support conditional query building:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ✅ CORRECT - Conditional building pattern
|
||||||
|
query := users.User.Select(users.ID, users.Email, users.Name)
|
||||||
|
|
||||||
|
// Add conditions based on filters
|
||||||
|
if nameFilter != "" {
|
||||||
|
query = query.Where(users.Name.Like("%" + nameFilter + "%"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusFilter > 0 {
|
||||||
|
query = query.Where(users.Status.Eq(statusFilter))
|
||||||
|
}
|
||||||
|
|
||||||
|
if sortByName {
|
||||||
|
query = query.OrderBy(users.Name.Asc())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the final query with all accumulated conditions
|
||||||
|
err := query.First(ctx, &id, &email, &name)
|
||||||
|
```
|
||||||
|
|
||||||
|
**This is the intended use!** The builder accumulates your conditions, which is powerful and flexible.
|
||||||
|
|
||||||
|
### ❌ Unintentional Reuse (INCORRECT)
|
||||||
|
|
||||||
|
Don't try to create a "base query" and reuse it for **multiple different queries**:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ WRONG - Trying to reuse for multiple separate queries
|
||||||
|
baseQuery := users.User.Select(users.ID, users.Email)
|
||||||
|
|
||||||
|
// First query - adds ID condition
|
||||||
|
baseQuery.Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||||
|
|
||||||
|
// Second query - ALSO has ID=1 from above PLUS Status=2!
|
||||||
|
baseQuery.Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||||
|
// This executes: WHERE users.id = 1 AND users.status = 2 (WRONG!)
|
||||||
|
|
||||||
|
// ✅ CORRECT - Each separate query gets its own builder
|
||||||
|
users.User.Select(users.ID, users.Email).Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||||
|
users.User.Select(users.ID, users.Email).Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why?** Query builders are mutable and accumulate state. Each method call modifies the builder, so reusing the same builder causes conditions to stack up.
|
||||||
|
|
||||||
|
### Thread Safety
|
||||||
|
|
||||||
|
⚠️ **Query builders are NOT thread-safe** and must not be shared across goroutines:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ✅ CORRECT - Each goroutine creates its own query
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
var email string
|
||||||
|
err := users.User.Select(users.Email).
|
||||||
|
Where(users.ID.Eq(id)).
|
||||||
|
First(ctx, &email)
|
||||||
|
// Process result...
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ❌ WRONG - Sharing query builder across goroutines
|
||||||
|
baseQuery := users.User.Select(users.Email)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
var email string
|
||||||
|
baseQuery.Where(users.ID.Eq(id)).First(ctx, &email)
|
||||||
|
// RACE CONDITION! Multiple goroutines modifying shared state
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Thread-Safe Components:**
|
||||||
|
- ✅ Connection Pool - Safe for concurrent use
|
||||||
|
- ✅ Table objects - Safe to share
|
||||||
|
- ❌ Query builders - Create new instance per goroutine
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### SELECT Queries
|
||||||
|
|
||||||
|
#### Basic Select
|
||||||
|
|
||||||
|
```go
|
||||||
|
var user struct {
|
||||||
|
ID string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
err := users.User.Select(users.ID, users.Email, users.Name).
|
||||||
|
Where(users.Email.Eq("john@example.com")).
|
||||||
|
First(ctx, &user.ID, &user.Email, &user.Name)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Select with Multiple Conditions
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Select(users.ID, users.Email).
|
||||||
|
Where(
|
||||||
|
users.Email.Like("john%"),
|
||||||
|
users.CreatedAt.Gt(time.Now().AddDate(0, -1, 0)),
|
||||||
|
).
|
||||||
|
OrderBy(users.CreatedAt.Desc()).
|
||||||
|
Limit(10).
|
||||||
|
First(ctx, &user.ID, &user.Email)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Select All with Callback
|
||||||
|
|
||||||
|
```go
|
||||||
|
var userList []User
|
||||||
|
|
||||||
|
err := users.User.Select(users.ID, users.Email, users.Name).
|
||||||
|
Where(users.Name.Like("J%")).
|
||||||
|
OrderBy(users.Name.Asc()).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
var u User
|
||||||
|
if err := row.Scan(&u.ID, &u.Email, &u.Name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userList = append(userList, u)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Pagination
|
||||||
|
|
||||||
|
```go
|
||||||
|
page := 2
|
||||||
|
pageSize := 20
|
||||||
|
|
||||||
|
err := users.User.Select(users.ID, users.Email).
|
||||||
|
OrderBy(users.CreatedAt.Desc()).
|
||||||
|
Limit(pageSize).
|
||||||
|
Offset((page - 1) * pageSize).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
// Process rows
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Grouping and Having
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := posts.Post.Select(posts.UserID, pgm.Count(posts.ID)).
|
||||||
|
GroupBy(posts.UserID).
|
||||||
|
Having(pgm.Count(posts.ID).Gt(5)).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
var userID string
|
||||||
|
var postCount int
|
||||||
|
return row.Scan(&userID, &postCount)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### INSERT Queries
|
||||||
|
|
||||||
|
#### Simple Insert
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe").
|
||||||
|
Set(users.CreatedAt, pgm.PgTimeNow()).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Insert with Map
|
||||||
|
|
||||||
|
```go
|
||||||
|
data := map[pgm.Field]any{
|
||||||
|
users.Email: "jane@example.com",
|
||||||
|
users.Name: "Jane Doe",
|
||||||
|
users.CreatedAt: pgm.PgTimeNow(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := users.User.Insert().
|
||||||
|
SetMap(data).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Insert with RETURNING
|
||||||
|
|
||||||
|
```go
|
||||||
|
var newID string
|
||||||
|
|
||||||
|
err := users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe").
|
||||||
|
Returning(users.ID).
|
||||||
|
First(ctx, &newID)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Upsert (INSERT ... ON CONFLICT)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Do nothing on conflict
|
||||||
|
err := users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe").
|
||||||
|
OnConflict(users.Email).
|
||||||
|
DoNothing().
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
// Update on conflict
|
||||||
|
err := users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe Updated").
|
||||||
|
OnConflict(users.Email).
|
||||||
|
DoUpdate(users.Name).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### UPDATE Queries
|
||||||
|
|
||||||
|
#### Simple Update
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Update().
|
||||||
|
Set(users.Name, "John Smith").
|
||||||
|
Where(users.ID.Eq("some-uuid")).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Update Multiple Fields
|
||||||
|
|
||||||
|
```go
|
||||||
|
updates := map[pgm.Field]any{
|
||||||
|
users.Name: "John Smith",
|
||||||
|
users.Email: "john.smith@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := users.User.Update().
|
||||||
|
SetMap(updates).
|
||||||
|
Where(users.ID.Eq("some-uuid")).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Conditional Update
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Update().
|
||||||
|
Set(users.Name, "Updated Name").
|
||||||
|
Where(
|
||||||
|
users.Email.Like("%@example.com"),
|
||||||
|
users.CreatedAt.Lt(time.Now().AddDate(-1, 0, 0)),
|
||||||
|
).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### DELETE Queries
|
||||||
|
|
||||||
|
#### Simple Delete
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Delete().
|
||||||
|
Where(users.ID.Eq("some-uuid")).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Conditional Delete
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := posts.Post.Delete().
|
||||||
|
Where(
|
||||||
|
posts.Published.Eq(false),
|
||||||
|
posts.CreatedAt.Lt(time.Now().AddDate(0, 0, -30)),
|
||||||
|
).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Joins
|
||||||
|
|
||||||
|
#### Inner Join
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := posts.Post.Select(posts.Title, users.Name).
|
||||||
|
Join(users.User, posts.UserID, users.ID).
|
||||||
|
Where(users.Email.Eq("john@example.com")).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
var title, userName string
|
||||||
|
return row.Scan(&title, &userName)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Left Join
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Select(users.Name, posts.Title).
|
||||||
|
LeftJoin(posts.Post, users.ID, posts.UserID).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
var userName, postTitle string
|
||||||
|
return row.Scan(&userName, &postTitle)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Join with Additional Conditions
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := posts.Post.Select(posts.Title, users.Name).
|
||||||
|
Join(users.User, posts.UserID, users.ID, users.Email.Like("%@example.com")).
|
||||||
|
Where(posts.Published.Eq(true)).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
// Process rows
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Transactions
|
||||||
|
|
||||||
|
#### Basic Transaction
|
||||||
|
|
||||||
|
```go
|
||||||
|
tx, err := pgm.BeginTx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
// Insert user
|
||||||
|
var userID string
|
||||||
|
err = users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe").
|
||||||
|
Returning(users.ID).
|
||||||
|
FirstTx(ctx, tx, &userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert post
|
||||||
|
err = posts.Post.Insert().
|
||||||
|
Set(posts.UserID, userID).
|
||||||
|
Set(posts.Title, "My First Post").
|
||||||
|
Set(posts.Content, "Hello, World!").
|
||||||
|
ExecTx(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit transaction
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Full-Text Search
|
||||||
|
|
||||||
|
PostgreSQL full-text search helpers:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Search with AND operator (all terms must match)
|
||||||
|
searchQuery := pgm.TsAndQuery("golang database")
|
||||||
|
// Result: "golang & database"
|
||||||
|
|
||||||
|
// Search with prefix matching
|
||||||
|
searchQuery := pgm.TsPrefixAndQuery("gol data")
|
||||||
|
// Result: "gol:* & data:*"
|
||||||
|
|
||||||
|
// Search with OR operator (any term matches)
|
||||||
|
searchQuery := pgm.TsOrQuery("golang rust")
|
||||||
|
// Result: "golang | rust"
|
||||||
|
|
||||||
|
// Prefix OR search
|
||||||
|
searchQuery := pgm.TsPrefixOrQuery("go ru")
|
||||||
|
// Result: "go:* | ru:*"
|
||||||
|
|
||||||
|
// Use in query (assuming you have a tsvector column)
|
||||||
|
err := posts.Post.Select(posts.Title, posts.Content).
|
||||||
|
Where(posts.SearchVector.Match(pgm.TsPrefixAndQuery(searchTerm))).
|
||||||
|
OrderBy(posts.CreatedAt.Desc()).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
// Process results
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## CLI Tool
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pgm -o <output_directory> <schema.sql>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Options
|
||||||
|
|
||||||
|
```bash
|
||||||
|
-o string Output directory path (required)
|
||||||
|
-version Show version information
|
||||||
|
```
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Generate from a single schema file
|
||||||
|
pgm -o ./db ./schema.sql
|
||||||
|
|
||||||
|
# Generate from concatenated migrations
|
||||||
|
cat migrations/*.sql > /tmp/schema.sql && pgm -o ./db /tmp/schema.sql
|
||||||
|
|
||||||
|
# Check version
|
||||||
|
pgm -version
|
||||||
|
```
|
||||||
|
|
||||||
|
### Known Limitations
|
||||||
|
|
||||||
|
The CLI tool uses a regex-based SQL parser with the following limitations:
|
||||||
|
|
||||||
|
- ❌ Multi-line comments `/* */` are not supported
|
||||||
|
- ❌ Complex data types (arrays, JSON, JSONB) may not parse correctly
|
||||||
|
- ❌ Quoted identifiers with special characters may fail
|
||||||
|
- ❌ Advanced PostgreSQL features (PARTITION BY, INHERITS) not supported
|
||||||
|
- ❌ Some constraints (CHECK, EXCLUDE) are not parsed
|
||||||
|
|
||||||
|
**Workarounds:**
|
||||||
|
- Use simple CREATE TABLE statements
|
||||||
|
- Avoid complex PostgreSQL-specific syntax in schema files
|
||||||
|
- Split complex schemas into multiple simple statements
|
||||||
|
- Remove comments before running the generator
|
||||||
|
|
||||||
|
For complex schemas, consider contributing a more robust parser or using a proper SQL parser library.
|
||||||
|
|
||||||
|
### Generated Code Structure
|
||||||
|
|
||||||
|
For a table named `users`, pgm generates:
|
||||||
|
|
||||||
|
```
|
||||||
|
db/
|
||||||
|
└── users/
|
||||||
|
└── users.go
|
||||||
|
```
|
||||||
|
|
||||||
|
The generated file contains:
|
||||||
|
- Generated code header with version and timestamp
|
||||||
|
- Table definition (`User`)
|
||||||
|
- Column field definitions (`ID`, `Email`, `Name`, etc.)
|
||||||
|
- Type-safe query builders (`Select()`, `Insert()`, `Update()`, `Delete()`)
|
||||||
|
|
||||||
|
**Example header:**
|
||||||
|
```go
|
||||||
|
// Code generated by code.patial.tech/go/pgm/cmd v1.2.3 on 2025-01-27 15:04:05 DO NOT EDIT.
|
||||||
|
```
|
||||||
|
|
||||||
|
The version in generated files helps track which version of the CLI tool was used, making it easier to identify when regeneration is needed after upgrades.
|
||||||
|
|
||||||
|
## API Documentation
|
||||||
|
|
||||||
|
### Connection Pool
|
||||||
|
|
||||||
|
#### InitPool
|
||||||
|
|
||||||
|
Initialize the connection pool (must be called once at startup):
|
||||||
|
|
||||||
|
```go
|
||||||
|
pgm.InitPool(pgm.Config{
|
||||||
|
ConnString: "postgres://...",
|
||||||
|
MaxConns: 25,
|
||||||
|
MinConns: 5,
|
||||||
|
MaxConnLifetime: time.Hour,
|
||||||
|
MaxConnIdleTime: time.Minute * 30,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
**Configuration Validation:**
|
||||||
|
- MinConns cannot be greater than MaxConns
|
||||||
|
- Connection counts cannot be negative
|
||||||
|
- Connection string is required
|
||||||
|
|
||||||
|
#### ClosePool
|
||||||
|
|
||||||
|
Close the connection pool gracefully (call during application shutdown):
|
||||||
|
|
||||||
|
```go
|
||||||
|
func main() {
|
||||||
|
pgm.InitPool(pgm.Config{
|
||||||
|
ConnString: "postgres://...",
|
||||||
|
})
|
||||||
|
defer pgm.ClosePool() // Ensures proper cleanup
|
||||||
|
|
||||||
|
// Your application code
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### GetPool
|
||||||
|
|
||||||
|
Get the underlying pgx pool:
|
||||||
|
|
||||||
|
```go
|
||||||
|
pool := pgm.GetPool()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Query Conditions
|
||||||
|
|
||||||
|
Available condition methods on fields:
|
||||||
|
|
||||||
|
- `Eq(value)` - Equal to
|
||||||
|
- `NotEq(value)` - Not equal to
|
||||||
|
- `Gt(value)` - Greater than
|
||||||
|
- `Gte(value)` - Greater than or equal to
|
||||||
|
- `Lt(value)` - Less than
|
||||||
|
- `Lte(value)` - Less than or equal to
|
||||||
|
- `Like(pattern)` - LIKE pattern match
|
||||||
|
- `ILike(pattern)` - Case-insensitive LIKE
|
||||||
|
- `In(values...)` - IN list
|
||||||
|
- `NotIn(values...)` - NOT IN list
|
||||||
|
- `IsNull()` - IS NULL
|
||||||
|
- `IsNotNull()` - IS NOT NULL
|
||||||
|
- `Between(start, end)` - BETWEEN range
|
||||||
|
|
||||||
|
### Field Methods
|
||||||
|
|
||||||
|
- `Asc()` - Sort ascending
|
||||||
|
- `Desc()` - Sort descending
|
||||||
|
- `Name()` - Get column name
|
||||||
|
- `String()` - Get fully qualified name (table.column)
|
||||||
|
|
||||||
|
### Utilities
|
||||||
|
|
||||||
|
- `pgm.PgTime(t time.Time)` - Convert Go time to PostgreSQL timestamptz
|
||||||
|
- `pgm.PgTimeNow()` - Current time as PostgreSQL timestamptz
|
||||||
|
- `pgm.IsNotFound(err)` - Check if error is "no rows found"
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Use transactions for related operations** - Ensure data consistency
|
||||||
|
2. **Define schema in SQL** - Use migration tools like dbmate for schema management
|
||||||
|
3. **Regenerate after schema changes** - Run the CLI tool after any schema modifications
|
||||||
|
4. **Use your own models** - Don't let the database dictate your domain models
|
||||||
|
5. **Handle pgx.ErrNoRows** - Use `pgm.IsNotFound(err)` for cleaner error checking
|
||||||
|
6. **Always use context with timeouts** - Prevent queries from running indefinitely
|
||||||
|
7. **Validate UPDATE queries** - Ensure Set() is called before Exec()
|
||||||
|
8. **Be careful with DELETE** - Always use Where() unless you want to delete all rows
|
||||||
|
|
||||||
|
## Important Safety Notes
|
||||||
|
|
||||||
|
### ⚠️ DELETE Operations
|
||||||
|
|
||||||
|
DELETE without WHERE clause will delete ALL rows in the table:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ DANGEROUS - Deletes ALL rows!
|
||||||
|
users.User.Delete().Exec(ctx)
|
||||||
|
|
||||||
|
// ✅ SAFE - Deletes specific rows
|
||||||
|
users.User.Delete().Where(users.ID.Eq("user-id")).Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### ⚠️ UPDATE Operations
|
||||||
|
|
||||||
|
UPDATE requires at least one Set() call:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ ERROR - No columns to update
|
||||||
|
users.User.Update().Where(users.ID.Eq(1)).Exec(ctx)
|
||||||
|
// Returns: "update query has no columns to update"
|
||||||
|
|
||||||
|
// ✅ CORRECT
|
||||||
|
users.User.Update().
|
||||||
|
Set(users.Name, "New Name").
|
||||||
|
Where(users.ID.Eq(1)).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### ⚠️ Query Timeouts
|
||||||
|
|
||||||
|
Always use context with timeout to prevent hanging queries:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ✅ RECOMMENDED
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := users.User.Select(users.Email).
|
||||||
|
Where(users.ID.Eq("some-id")).
|
||||||
|
First(ctx, &email)
|
||||||
|
```
|
||||||
|
|
||||||
|
### ⚠️ Connection String Security
|
||||||
|
|
||||||
|
Never log or expose database connection strings as they contain credentials. The library does not sanitize connection strings in error messages.
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
**pgm** is designed for performance:
|
||||||
|
|
||||||
|
- Zero reflection overhead
|
||||||
|
- Efficient string building with sync.Pool
|
||||||
|
- Leverages pgx's high-performance connection pooling
|
||||||
|
- Minimal allocations in query building
|
||||||
|
- Direct scanning into your types
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- Go 1.20 or higher
|
||||||
|
- PostgreSQL 12 or higher
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||||
|
|
||||||
|
## Author
|
||||||
|
|
||||||
|
|
||||||
|
**Ankit Patial** - [Patial Tech](https://code.patial.tech)
|
||||||
|
|
||||||
|
## Acknowledgments
|
||||||
|
|
||||||
|
Built on top of the excellent [jackc/pgx](https://github.com/jackc/pgx) library.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Made with ❤️ for the Go community**
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ import (
|
|||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// version can be set at build time using:
|
||||||
|
// go build -ldflags "-X main.version=v1.2.3" ./cmd
|
||||||
|
var version = "dev"
|
||||||
|
|
||||||
func generate(scheamPath, outDir string) error {
|
func generate(scheamPath, outDir string) error {
|
||||||
// read schame.sql
|
// read schame.sql
|
||||||
f, err := os.ReadFile(scheamPath)
|
f, err := os.ReadFile(scheamPath)
|
||||||
@@ -29,14 +33,16 @@ func generate(scheamPath, outDir string) error {
|
|||||||
|
|
||||||
// Output dir, create if not exists.
|
// Output dir, create if not exists.
|
||||||
if _, err := os.Stat(outDir); os.IsNotExist(err) {
|
if _, err := os.Stat(outDir); os.IsNotExist(err) {
|
||||||
if err := os.MkdirAll(outDir, 0740); err != nil {
|
if err := os.MkdirAll(outDir, 0755); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// schema.go will hold all tables info
|
// schema.go will hold all tables info
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("// Code generated by code.patial.tech/go/pgm/cmd\n// DO NOT EDIT.\n\n")
|
sb.WriteString(
|
||||||
|
fmt.Sprintf("// Code generated by code.patial.tech/go/pgm/cmd %s DO NOT EDIT.\n\n", GetVersionString()),
|
||||||
|
)
|
||||||
sb.WriteString(fmt.Sprintf("package %s \n", filepath.Base(outDir)))
|
sb.WriteString(fmt.Sprintf("package %s \n", filepath.Base(outDir)))
|
||||||
sb.WriteString(`
|
sb.WriteString(`
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
@@ -70,14 +76,24 @@ func generate(scheamPath, outDir string) error {
|
|||||||
sb.WriteString("}\n")
|
sb.WriteString("}\n")
|
||||||
}
|
}
|
||||||
modalDir = strings.ToLower(name)
|
modalDir = strings.ToLower(name)
|
||||||
os.Mkdir(filepath.Join(outDir, modalDir), 0740)
|
os.Mkdir(filepath.Join(outDir, modalDir), 0755)
|
||||||
|
|
||||||
if err = writeColFile(t.Table, t.Columns, filepath.Join(outDir, modalDir), caser); err != nil {
|
if err = writeColFile(t.Table, t.Columns, filepath.Join(outDir, modalDir), caser); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(")")
|
sb.WriteString(")")
|
||||||
|
|
||||||
|
sb.WriteString(`
|
||||||
|
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||||
|
t := pgm.Table{
|
||||||
|
Name: tblName,
|
||||||
|
DerivedTable: fromQry,
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}`)
|
||||||
|
|
||||||
// Format code before saving
|
// Format code before saving
|
||||||
code, err := formatGoCode(sb.String())
|
code, err := formatGoCode(sb.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -85,16 +101,20 @@ func generate(scheamPath, outDir string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Save file to disk
|
// Save file to disk
|
||||||
os.WriteFile(filepath.Join(outDir, "schema.go"), code, 0640)
|
os.WriteFile(filepath.Join(outDir, "schema.go"), code, 0644)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Caser) error {
|
func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Caser) error {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("// Code generated by db-gen. DO NOT EDIT.\n\n")
|
sb.WriteString(
|
||||||
|
fmt.Sprintf("// Code generated by code.patial.tech/go/pgm/cmd %s DO NOT EDIT.\n\n", GetVersionString()),
|
||||||
|
)
|
||||||
sb.WriteString(fmt.Sprintf("package %s\n\n", filepath.Base(outDir)))
|
sb.WriteString(fmt.Sprintf("package %s\n\n", filepath.Base(outDir)))
|
||||||
sb.WriteString(fmt.Sprintf("import %q\n\n", "code.patial.tech/go/pgm"))
|
sb.WriteString(fmt.Sprintf("import %q\n\n", "code.patial.tech/go/pgm"))
|
||||||
sb.WriteString("const (")
|
sb.WriteString("const (")
|
||||||
|
sb.WriteString("\n // All fields in table " + tblName)
|
||||||
|
sb.WriteString(fmt.Sprintf("\n All pgm.Field = %q", tblName+".*"))
|
||||||
var name string
|
var name string
|
||||||
for _, c := range cols {
|
for _, c := range cols {
|
||||||
name = strings.ReplaceAll(c.Name, "_", " ")
|
name = strings.ReplaceAll(c.Name, "_", " ")
|
||||||
@@ -117,7 +137,7 @@ func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Cas
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Save file to disk.
|
// Save file to disk.
|
||||||
return os.WriteFile(filepath.Join(outDir, tblName+".go"), code, 0640)
|
return os.WriteFile(filepath.Join(outDir, tblName+".go"), code, 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
// pluralToSingular converts plural table names to singular forms
|
// pluralToSingular converts plural table names to singular forms
|
||||||
|
|||||||
21
cmd/main.go
21
cmd/main.go
@@ -9,29 +9,42 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageTxt = `Please provide output director and input schema.
|
var (
|
||||||
|
showVersion bool
|
||||||
|
)
|
||||||
|
|
||||||
|
const usageTxt = `Please provide output directory and input schema.
|
||||||
Example:
|
Example:
|
||||||
pgm/cmd -o ./db ./db/schema.sql
|
pgm -o ./db ./schema.sql
|
||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var outDir string
|
var outDir string
|
||||||
flag.StringVar(&outDir, "o", "", "-o as output directory path")
|
flag.StringVar(&outDir, "o", "", "-o as output directory path")
|
||||||
|
flag.BoolVar(&showVersion, "version", false, "show version information")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
// Handle version flag
|
||||||
|
if showVersion {
|
||||||
|
fmt.Printf("pgm %s\n", GetVersionString())
|
||||||
|
fmt.Println("PostgreSQL Query Mapper - Schema code generator")
|
||||||
|
fmt.Println("https://code.patial.tech/go/pgm")
|
||||||
|
return
|
||||||
|
}
|
||||||
if len(os.Args) < 4 {
|
if len(os.Args) < 4 {
|
||||||
fmt.Print(usageTxt)
|
fmt.Print(usageTxt)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if outDir == "" {
|
if outDir == "" {
|
||||||
println("missing, -o output directory path")
|
fmt.Fprintln(os.Stderr, "Error: missing output directory path (-o flag required)")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := generate(os.Args[3], outDir); err != nil {
|
if err := generate(os.Args[3], outDir); err != nil {
|
||||||
println(err.Error())
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
15
cmd/parse.go
15
cmd/parse.go
@@ -1,5 +1,20 @@
|
|||||||
// Patial Tech.
|
// Patial Tech.
|
||||||
// Author, Ankit Patial
|
// Author, Ankit Patial
|
||||||
|
//
|
||||||
|
// SQL Parser Limitations:
|
||||||
|
// This is a simple regex-based SQL parser with the following known limitations:
|
||||||
|
// - No support for multi-line comments /* */
|
||||||
|
// - May struggle with complex data types (e.g., arrays, JSON, JSONB)
|
||||||
|
// - No handling of quoted identifiers with special characters
|
||||||
|
// - Advanced features like PARTITION BY, INHERITS not supported
|
||||||
|
// - Does not handle all PostgreSQL-specific syntax
|
||||||
|
// - Constraints like CHECK, EXCLUDE not parsed
|
||||||
|
//
|
||||||
|
// For complex schemas, consider:
|
||||||
|
// 1. Simplifying your schema for generation
|
||||||
|
// 2. Using multiple simple CREATE TABLE statements
|
||||||
|
// 3. Contributing a more robust parser implementation
|
||||||
|
// 4. Using a proper SQL parser library
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
|||||||
67
cmd/version.go
Normal file
67
cmd/version.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// Patial Tech.
|
||||||
|
// Author, Ankit Patial
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Version returns the version of the pgm CLI tool.
|
||||||
|
// It tries to detect the version in the following order:
|
||||||
|
// 1. Build-time ldflags (set via: go build -ldflags "-X main.version=v1.2.3")
|
||||||
|
// 2. VCS information from build metadata (git tag/commit)
|
||||||
|
// 3. Falls back to "dev" if no version information is available
|
||||||
|
func Version() string {
|
||||||
|
// If version was set at build time via ldflags
|
||||||
|
if version != "" && version != "dev" {
|
||||||
|
return version
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get version from build info (Go 1.18+)
|
||||||
|
if info, ok := debug.ReadBuildInfo(); ok {
|
||||||
|
// Check for version in main module
|
||||||
|
if info.Main.Version != "" && info.Main.Version != "(devel)" {
|
||||||
|
return info.Main.Version
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to extract from VCS information
|
||||||
|
var revision, modified string
|
||||||
|
for _, setting := range info.Settings {
|
||||||
|
switch setting.Key {
|
||||||
|
case "vcs.revision":
|
||||||
|
revision = setting.Value
|
||||||
|
case "vcs.modified":
|
||||||
|
modified = setting.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have a git revision
|
||||||
|
if revision != "" {
|
||||||
|
// Shorten commit hash to 7 characters
|
||||||
|
if len(revision) > 7 {
|
||||||
|
revision = revision[:7]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add -dirty suffix if modified
|
||||||
|
if modified == "true" {
|
||||||
|
return "dev-" + revision + "-dirty"
|
||||||
|
}
|
||||||
|
return "dev-" + revision
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to dev
|
||||||
|
return "dev"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVersionString returns a formatted version string for display
|
||||||
|
func GetVersionString() string {
|
||||||
|
v := Version()
|
||||||
|
|
||||||
|
// Clean up version string for display
|
||||||
|
v = strings.TrimPrefix(v, "v")
|
||||||
|
|
||||||
|
return "v" + v
|
||||||
|
}
|
||||||
@@ -1,90 +0,0 @@
|
|||||||
package example
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"code.patial.tech/go/pgm"
|
|
||||||
"code.patial.tech/go/pgm/example/db"
|
|
||||||
"code.patial.tech/go/pgm/example/db/branchuser"
|
|
||||||
"code.patial.tech/go/pgm/example/db/employee"
|
|
||||||
"code.patial.tech/go/pgm/example/db/user"
|
|
||||||
"code.patial.tech/go/pgm/example/db/usersession"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestQryBuilder2(t *testing.T) {
|
|
||||||
got := db.User.Debug().Select(user.Email, user.FirstName).
|
|
||||||
Join(db.UserSession, user.ID, usersession.UserID).
|
|
||||||
Join(db.BranchUser, user.ID, branchuser.UserID).
|
|
||||||
Where(
|
|
||||||
user.ID.Eq(1),
|
|
||||||
pgm.Or(
|
|
||||||
user.StatusID.Eq(2),
|
|
||||||
user.UpdatedAt.Eq(3),
|
|
||||||
),
|
|
||||||
user.MfaKind.Eq(4),
|
|
||||||
pgm.Or(
|
|
||||||
user.FirstName.Eq(5),
|
|
||||||
user.MiddleName.Eq(6),
|
|
||||||
),
|
|
||||||
).
|
|
||||||
Where(
|
|
||||||
user.LastName.NEq(7),
|
|
||||||
user.Phone.Like("%123%"),
|
|
||||||
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
|
||||||
).
|
|
||||||
Limit(10).
|
|
||||||
Offset(100).
|
|
||||||
String()
|
|
||||||
|
|
||||||
expected := "SELECT users.email, users.first_name FROM users JOIN user_sessions ON users.id = user_sessions.user_id" +
|
|
||||||
" JOIN branch_users ON users.id = branch_users.user_id WHERE users.id = $1 AND (users.status_id = $2 OR users.updated_at = $3)" +
|
|
||||||
" AND users.mfa_kind = $4 AND (users.first_name = $5 OR users.middle_name = $6) AND users.last_name != $7 AND users.phone" +
|
|
||||||
" LIKE $8 AND users.email NOT IN(SELECT users.id FROM users WHERE users.id = $9) LIMIT 10 OFFSET 100"
|
|
||||||
if expected != got {
|
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSelectWithHaving(t *testing.T) {
|
|
||||||
expected := "SELECT employees.department, AVG(employees.salary), COUNT(employees.id)" +
|
|
||||||
" FROM employees GROUP BY employees.department HAVING AVG(employees.salary) > $1 AND COUNT(employees.id) > $2"
|
|
||||||
got := db.Employee.
|
|
||||||
Select(employee.Department, employee.Salary.Avg(), employee.ID.Count()).
|
|
||||||
GroupBy(employee.Department).
|
|
||||||
Having(employee.Salary.Avg().Gt(50000), employee.ID.Count().Gt(5)).
|
|
||||||
String()
|
|
||||||
|
|
||||||
if expected != got {
|
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkSelect-12 668817 1753 ns/op 4442 B/op 59 allocs/op
|
|
||||||
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
|
|
||||||
func BenchmarkSelect(b *testing.B) {
|
|
||||||
for b.Loop() {
|
|
||||||
_ = db.User.Select(user.Email, user.FirstName).
|
|
||||||
Join(db.UserSession, user.ID, usersession.UserID).
|
|
||||||
Join(db.BranchUser, user.ID, branchuser.UserID).
|
|
||||||
Where(
|
|
||||||
user.ID.Eq(1),
|
|
||||||
pgm.Or(
|
|
||||||
user.StatusID.Eq(2),
|
|
||||||
user.UpdatedAt.Eq(3),
|
|
||||||
),
|
|
||||||
user.MfaKind.Eq(4),
|
|
||||||
pgm.Or(
|
|
||||||
user.FirstName.Eq(5),
|
|
||||||
user.MiddleName.Eq(6),
|
|
||||||
),
|
|
||||||
).
|
|
||||||
Where(
|
|
||||||
user.LastName.NEq(7),
|
|
||||||
user.Phone.Like("%123%"),
|
|
||||||
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
|
||||||
).
|
|
||||||
Limit(10).
|
|
||||||
Offset(100).
|
|
||||||
String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
package example
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"code.patial.tech/go/pgm"
|
|
||||||
"code.patial.tech/go/pgm/example/db"
|
|
||||||
"code.patial.tech/go/pgm/example/db/user"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUpdateQuery(t *testing.T) {
|
|
||||||
got := db.User.Update().
|
|
||||||
Set(user.FirstName, "ankit").
|
|
||||||
Set(user.MiddleName, "singh").
|
|
||||||
Set(user.LastName, "patial").
|
|
||||||
Where(
|
|
||||||
user.Email.Eq("aa@aa.com"),
|
|
||||||
).
|
|
||||||
Where(
|
|
||||||
user.StatusID.NEq(1),
|
|
||||||
).
|
|
||||||
String()
|
|
||||||
|
|
||||||
expected := "UPDATE users SET first_name=$1, middle_name=$2, last_name=$3 WHERE users.email = $4 AND users.status_id != $5"
|
|
||||||
if got != expected {
|
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateSetMap(t *testing.T) {
|
|
||||||
got := db.User.Update().
|
|
||||||
SetMap(map[pgm.Field]any{
|
|
||||||
user.FirstName: "ankit",
|
|
||||||
user.MiddleName: "singh",
|
|
||||||
user.LastName: "patial",
|
|
||||||
}).
|
|
||||||
Where(
|
|
||||||
user.Email.Eq("aa@aa.com"),
|
|
||||||
).
|
|
||||||
Where(
|
|
||||||
user.StatusID.NEq(1),
|
|
||||||
).
|
|
||||||
String()
|
|
||||||
|
|
||||||
expected := "UPDATE users SET first_name=$1, middle_name=$2, last_name=$3 WHERE users.email = $4 AND users.status_id != $5"
|
|
||||||
if got != expected {
|
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkUpdateQuery-12 2004985 592.2 ns/op 1176 B/op 20 allocs/op
|
|
||||||
func BenchmarkUpdateQuery(b *testing.B) {
|
|
||||||
for b.Loop() {
|
|
||||||
_ = db.User.Update().
|
|
||||||
Set(user.FirstName, "ankit").
|
|
||||||
Set(user.MiddleName, "singh").
|
|
||||||
Set(user.LastName, "patial").
|
|
||||||
Where(
|
|
||||||
user.Email.Eq("aa@aa.com"),
|
|
||||||
).
|
|
||||||
Where(
|
|
||||||
user.StatusID.NEq(1),
|
|
||||||
).
|
|
||||||
String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
13
go.mod
13
go.mod
@@ -3,19 +3,14 @@ module code.patial.tech/go/pgm
|
|||||||
go 1.24.5
|
go 1.24.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/jackc/pgx v3.6.2+incompatible
|
github.com/jackc/pgx/v5 v5.7.6
|
||||||
golang.org/x/text v0.27.0
|
golang.org/x/text v0.31.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
golang.org/x/sync v0.16.0 // indirect
|
golang.org/x/crypto v0.45.0 // indirect
|
||||||
)
|
golang.org/x/sync v0.18.0 // indirect
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/jackc/pgx/v5 v5.7.5
|
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
|
||||||
golang.org/x/crypto v0.40.0 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
19
go.sum
19
go.sum
@@ -1,25 +1,36 @@
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o=
|
|
||||||
github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I=
|
|
||||||
github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs=
|
github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs=
|
||||||
github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||||
|
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
|
||||||
|
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||||
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||||
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||||
|
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||||
|
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
255
pgm.go
255
pgm.go
@@ -1,134 +1,197 @@
|
|||||||
// Patial Tech.
|
// pgm
|
||||||
// Author, Ankit Patial
|
//
|
||||||
|
// A simple PG string query builder
|
||||||
|
//
|
||||||
|
// Author: Ankit Patial
|
||||||
|
|
||||||
package pgm
|
package pgm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Table in database
|
var (
|
||||||
type Table struct {
|
poolPGX atomic.Pointer[pgxpool.Pool]
|
||||||
Name string
|
ErrConnStringMissing = errors.New("connection string is empty")
|
||||||
PK []string
|
)
|
||||||
FieldCount uint16
|
|
||||||
debug bool
|
// Common errors returned by pgm operations
|
||||||
}
|
var (
|
||||||
|
ErrInitTX = errors.New("failed to init db.tx")
|
||||||
// Debug when set true will print generated query string in stdout
|
ErrCommitTX = errors.New("failed to commit db.tx")
|
||||||
func (t Table) Debug() Clause {
|
ErrNoRows = errors.New("no data found")
|
||||||
t.debug = true
|
)
|
||||||
return t
|
|
||||||
|
// Config holds the configuration for initializing the connection pool.
|
||||||
|
// All fields except ConnString are optional and will use pgx defaults if not set.
|
||||||
|
type Config struct {
|
||||||
|
ConnString string
|
||||||
|
MaxConns int32
|
||||||
|
MinConns int32
|
||||||
|
MaxConnLifetime time.Duration
|
||||||
|
MaxConnIdleTime time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitPool initializes the connection pool with the provided configuration.
|
||||||
|
// It validates the configuration and panics if invalid.
|
||||||
|
// This function should be called once at application startup.
|
||||||
//
|
//
|
||||||
// Field ==>
|
// Example:
|
||||||
//
|
//
|
||||||
|
// pgm.InitPool(pgm.Config{
|
||||||
|
// ConnString: "postgres://user:pass@localhost/dbname",
|
||||||
|
// MaxConns: 100,
|
||||||
|
// MinConns: 5,
|
||||||
|
// })
|
||||||
|
func InitPool(conf Config) {
|
||||||
|
if conf.ConnString == "" {
|
||||||
|
panic(ErrConnStringMissing)
|
||||||
|
}
|
||||||
|
|
||||||
// Field related to a table
|
// Validate configuration
|
||||||
type Field string
|
if conf.MaxConns > 0 && conf.MinConns > 0 && conf.MinConns > conf.MaxConns {
|
||||||
|
panic(fmt.Errorf("MinConns (%d) cannot be greater than MaxConns (%d)", conf.MinConns, conf.MaxConns))
|
||||||
|
}
|
||||||
|
|
||||||
func (f Field) Name() string {
|
if conf.MaxConns < 0 || conf.MinConns < 0 {
|
||||||
return strings.Split(string(f), ".")[1]
|
panic(errors.New("connection pool configuration cannot have negative values"))
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := pgxpool.ParseConfig(conf.ConnString)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.MaxConns > 0 {
|
||||||
|
cfg.MaxConns = conf.MaxConns // 100
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.MinConns > 0 {
|
||||||
|
cfg.MinConns = conf.MinConns // 5
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.MaxConnLifetime > 0 {
|
||||||
|
cfg.MaxConnLifetime = conf.MaxConnLifetime // time.Minute * 10
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.MaxConnIdleTime > 0 {
|
||||||
|
cfg.MaxConnIdleTime = conf.MaxConnIdleTime // time.Minute * 5
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := pgxpool.NewWithConfig(context.Background(), cfg)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = p.Ping(context.Background()); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
poolPGX.Store(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f Field) String() string {
|
// GetPool returns the initialized connection pool instance.
|
||||||
return string(f)
|
// It panics with a descriptive message if InitPool() has not been called.
|
||||||
|
// This is a fail-fast approach to catch programming errors early.
|
||||||
|
func GetPool() *pgxpool.Pool {
|
||||||
|
p := poolPGX.Load()
|
||||||
|
if p == nil {
|
||||||
|
panic("pgm: connection pool not initialized, call InitPool() first")
|
||||||
|
}
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count fn wrapping of field
|
// ClosePool closes the connection pool gracefully.
|
||||||
func (f Field) Count() Field {
|
// Should be called during application shutdown.
|
||||||
return Field("COUNT(" + f.String() + ")")
|
func ClosePool() {
|
||||||
}
|
if p := poolPGX.Load(); p != nil {
|
||||||
|
p.Close()
|
||||||
// Avg fn wrapping of field
|
poolPGX.Store(nil)
|
||||||
func (f Field) Avg() Field {
|
}
|
||||||
return Field("AVG(" + f.String() + ")")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Eq(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " = $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EqualFold will user LOWER() for comparision
|
|
||||||
func (f Field) EqFold(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " = LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) NEq(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " != $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Gt(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " > $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Gte(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " >= $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Like(val string) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " LIKE $", len: len(f.String()) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) LikeFold(val string) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " LIKE LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ILIKE is case-insensitive
|
|
||||||
func (f Field) ILike(val string) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " ILIKE $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) NotIn(val ...any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " NOT IN($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: qry, op: " NOT IN($)", action: CondActionSubQuery}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BeginTx begins a new database transaction from the connection pool.
|
||||||
|
// Returns an error if the transaction cannot be started.
|
||||||
|
// Remember to commit or rollback the transaction when done.
|
||||||
//
|
//
|
||||||
// Helper func ==>
|
// Example:
|
||||||
//
|
//
|
||||||
|
// tx, err := pgm.BeginTx(ctx)
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// defer tx.Rollback(ctx) // rollback on error
|
||||||
|
//
|
||||||
|
// // ... do work ...
|
||||||
|
//
|
||||||
|
// return tx.Commit(ctx)
|
||||||
|
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
||||||
|
tx, err := poolPGX.Load().Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to begin transaction", "error", err)
|
||||||
|
return nil, fmt.Errorf("failed to open db tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// PgTime as in UTC
|
return tx, nil
|
||||||
func PgTime(t time.Time) pgtype.Timestamptz {
|
|
||||||
return pgtype.Timestamptz{Time: t, Valid: true}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func PgTimeNow() pgtype.Timestamptz {
|
// IsNotFound checks if an error is a "no rows" error from pgx.
|
||||||
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
// Returns true if the error indicates no rows were found in a query result.
|
||||||
}
|
|
||||||
|
|
||||||
// IsNotFound error check
|
|
||||||
func IsNotFound(err error) bool {
|
func IsNotFound(err error) bool {
|
||||||
return errors.Is(err, pgx.ErrNoRows)
|
return errors.Is(err, pgx.ErrNoRows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConcatWs(sep string, fields ...Field) string {
|
// PgTime converts a Go time.Time to PostgreSQL timestamptz type.
|
||||||
return "concat_ws('" + sep + "'," + joinFileds(fields) + ")"
|
// The time is stored as-is (preserves timezone information).
|
||||||
|
func PgTime(t time.Time) pgtype.Timestamptz {
|
||||||
|
return pgtype.Timestamptz{Time: t, Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
func StringAgg(exp, sep string) string {
|
// PgTimeNow returns the current time as PostgreSQL timestamptz type.
|
||||||
return "string_agg(" + exp + ",'" + sep + "')"
|
func PgTimeNow() pgtype.Timestamptz {
|
||||||
|
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
func StringAggCast(exp, sep string) string {
|
// TsAndQuery converts a text search query to use AND operator between terms.
|
||||||
return "string_agg(cast(" + exp + " as varchar),'" + sep + "')"
|
// Example: "hello world" becomes "hello & world"
|
||||||
|
func TsAndQuery(q string) string {
|
||||||
|
return strings.Join(strings.Fields(q), " & ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TsPrefixAndQuery converts a text search query to use AND operator with prefix matching.
|
||||||
|
// Example: "hello world" becomes "hello:* & world:*"
|
||||||
|
func TsPrefixAndQuery(q string) string {
|
||||||
|
return strings.Join(fieldsWithSufix(q, ":*"), " & ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TsOrQuery converts a text search query to use OR operator between terms.
|
||||||
|
// Example: "hello world" becomes "hello | world"
|
||||||
|
func TsOrQuery(q string) string {
|
||||||
|
return strings.Join(strings.Fields(q), " | ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TsPrefixOrQuery converts a text search query to use OR operator with prefix matching.
|
||||||
|
// Example: "hello world" becomes "hello:* | world:*"
|
||||||
|
func TsPrefixOrQuery(q string) string {
|
||||||
|
return strings.Join(fieldsWithSufix(q, ":*"), " | ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func fieldsWithSufix(v, sufix string) []string {
|
||||||
|
fields := strings.Fields(v)
|
||||||
|
prefixed := make([]string, len(fields))
|
||||||
|
for i, f := range fields {
|
||||||
|
prefixed[i] = f + sufix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefixed
|
||||||
}
|
}
|
||||||
|
|||||||
324
pgm_field.go
Normal file
324
pgm_field.go
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
package pgm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Field related to a table
|
||||||
|
type Field string
|
||||||
|
|
||||||
|
var (
|
||||||
|
// sqlIdentifierRegex validates SQL identifiers (alphanumeric and underscore only)
|
||||||
|
sqlIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||||
|
|
||||||
|
// validDateTruncLevels contains all allowed DATE_TRUNC precision levels
|
||||||
|
validDateTruncLevels = map[string]bool{
|
||||||
|
"microseconds": true,
|
||||||
|
"milliseconds": true,
|
||||||
|
"second": true,
|
||||||
|
"minute": true,
|
||||||
|
"hour": true,
|
||||||
|
"day": true,
|
||||||
|
"week": true,
|
||||||
|
"month": true,
|
||||||
|
"quarter": true,
|
||||||
|
"year": true,
|
||||||
|
"decade": true,
|
||||||
|
"century": true,
|
||||||
|
"millennium": true,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// validateSQLIdentifier checks if a string is a valid SQL identifier
|
||||||
|
func validateSQLIdentifier(s string) error {
|
||||||
|
if !sqlIdentifierRegex.MatchString(s) {
|
||||||
|
return fmt.Errorf("invalid SQL identifier: %q (must be alphanumeric and underscore only)", s)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// escapeSQLString escapes single quotes in a string for SQL
|
||||||
|
func escapeSQLString(s string) string {
|
||||||
|
return strings.ReplaceAll(s, "'", "''")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Name() string {
|
||||||
|
s := string(f)
|
||||||
|
idx := strings.LastIndexByte(s, '.')
|
||||||
|
if idx == -1 {
|
||||||
|
return s // Return as-is if no dot
|
||||||
|
}
|
||||||
|
return s[idx+1:] // Return part after last dot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) String() string {
|
||||||
|
return string(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count function wrapped field
|
||||||
|
func (f Field) Count() Field {
|
||||||
|
return Field("COUNT(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConcatWs creates a CONCAT_WS SQL function.
|
||||||
|
// SECURITY: The sep parameter should only be a constant string, not user input.
|
||||||
|
// Single quotes in sep will be escaped automatically.
|
||||||
|
func ConcatWs(sep string, fields ...Field) Field {
|
||||||
|
escapedSep := escapeSQLString(sep)
|
||||||
|
return Field("concat_ws('" + escapedSep + "'," + joinFileds(fields) + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StringAgg creates a STRING_AGG SQL function.
|
||||||
|
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||||
|
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||||
|
func StringAgg(field Field, sep string) Field {
|
||||||
|
escapedSep := escapeSQLString(sep)
|
||||||
|
return Field("string_agg(" + field.String() + ",'" + escapedSep + "')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
|
||||||
|
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||||
|
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||||
|
func StringAggCast(field Field, sep string) Field {
|
||||||
|
escapedSep := escapeSQLString(sep)
|
||||||
|
return Field("string_agg(cast(" + field.String() + " as varchar),'" + escapedSep + "')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StringEscape will wrap field with:
|
||||||
|
//
|
||||||
|
// COALESCE(field, ”)
|
||||||
|
func (f Field) StringEscape() Field {
|
||||||
|
return Field("COALESCE(" + f.String() + ", '')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumberEscape will wrap field with:
|
||||||
|
//
|
||||||
|
// COALESCE(field, 0)
|
||||||
|
func (f Field) NumberEscape() Field {
|
||||||
|
return Field("COALESCE(" + f.String() + ", 0)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BooleanEscape will wrap field with:
|
||||||
|
//
|
||||||
|
// COALESCE(field, FALSE)
|
||||||
|
func (f Field) BooleanEscape() Field {
|
||||||
|
return Field("COALESCE(" + f.String() + ", FALSE)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Avg function wrapped field
|
||||||
|
func (f Field) Avg() Field {
|
||||||
|
return Field("AVG(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sum function wrapped field
|
||||||
|
func (f Field) Sum() Field {
|
||||||
|
return Field("SUM(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Max function wrapped field
|
||||||
|
func (f Field) Max() Field {
|
||||||
|
return Field("MAX(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Min function wrapped field
|
||||||
|
func (f Field) Min() Field {
|
||||||
|
return Field("Min(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lower function wrapped field
|
||||||
|
func (f Field) Lower() Field {
|
||||||
|
return Field("LOWER(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upper function wrapped field
|
||||||
|
func (f Field) Upper() Field {
|
||||||
|
return Field("UPPER(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trim function wrapped field
|
||||||
|
func (f Field) Trim() Field {
|
||||||
|
return Field("TRIM(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Asc suffixed field, supposed to be used with order by
|
||||||
|
func (f Field) Asc() Field {
|
||||||
|
return Field(f.String() + " ASC")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Desc suffixed field, supposed to be used with order by
|
||||||
|
func (f Field) Desc() Field {
|
||||||
|
return Field(f.String() + " DESC")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) RowNumber(as string) Field {
|
||||||
|
return rowNumber(&f, nil, true, as)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) RowNumberDesc(as string) Field {
|
||||||
|
return rowNumber(&f, nil, true, as)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowNumberPartionBy in ascending order
|
||||||
|
func (f Field) RowNumberPartionBy(partition Field, as string) Field {
|
||||||
|
return rowNumber(&f, &partition, true, as)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) RowNumberDescPartionBy(partition Field, as string) Field {
|
||||||
|
return rowNumber(&f, &partition, false, as)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rowNumber(f, partition *Field, isAsc bool, as string) Field {
|
||||||
|
// Validate as parameter is a valid SQL identifier
|
||||||
|
if as != "" {
|
||||||
|
if err := validateSQLIdentifier(as); err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid AS alias in rowNumber: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var orderBy string
|
||||||
|
if isAsc {
|
||||||
|
orderBy = " ASC"
|
||||||
|
} else {
|
||||||
|
orderBy = " DESC"
|
||||||
|
}
|
||||||
|
|
||||||
|
if as == "" {
|
||||||
|
as = "row_number"
|
||||||
|
}
|
||||||
|
|
||||||
|
col := f.String()
|
||||||
|
if partition != nil {
|
||||||
|
return Field("ROW_NUMBER() OVER (PARTITION BY " + partition.String() + " ORDER BY " + col + orderBy + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Field("ROW_NUMBER() OVER (ORDER BY " + col + orderBy + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) IsNull() Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, op: " IS NULL", len: len(col) + 8}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) IsNotNull() Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, op: " IS NOT NULL", len: len(col) + 12}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DateTrunc will truncate date or timestamp to specified level of precision
|
||||||
|
//
|
||||||
|
// Level values:
|
||||||
|
// - microseconds, milliseconds, second, minute, hour
|
||||||
|
// - day, week (Monday start), month, quarter, year
|
||||||
|
// - decade, century, millennium
|
||||||
|
func (f Field) DateTrunc(level, as string) Field {
|
||||||
|
// Validate level parameter against allowed values
|
||||||
|
if !validDateTruncLevels[strings.ToLower(level)] {
|
||||||
|
panic(fmt.Sprintf("invalid DATE_TRUNC level: %q (allowed: microseconds, milliseconds, second, minute, hour, day, week, month, quarter, year, decade, century, millennium)", level))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate as parameter is a valid SQL identifier
|
||||||
|
if err := validateSQLIdentifier(as); err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid AS alias in DateTrunc: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return Field("DATE_TRUNC('" + strings.ToLower(level) + "', " + f.String() + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) TsRank(fieldName, as string) Field {
|
||||||
|
// Validate as parameter is a valid SQL identifier
|
||||||
|
if err := validateSQLIdentifier(as); err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid AS alias in TsRank: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EqualFold will use LOWER(column_name) = LOWER(val) for comparision
|
||||||
|
func (f Field) EqFold(val string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " = LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Eq is equal
|
||||||
|
func (f Field) Eq(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " = $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) NotEq(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " != $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Gt(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " > $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Lt(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " < $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Gte(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " >= $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Lte(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " <= $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Like(val string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " LIKE $", len: len(f.String()) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) LikeFold(val string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " LIKE LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ILIKE is case-insensitive
|
||||||
|
func (f Field) ILike(val string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " ILIKE $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Any(val ...any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " = ANY($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) NotAny(val ...any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " != ANY($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotInSubQuery using ANY
|
||||||
|
func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) TsQuery(as string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinFileds(fields []Field) string {
|
||||||
|
sb := getSB()
|
||||||
|
defer putSB(sb)
|
||||||
|
for i, f := range fields {
|
||||||
|
if i == 0 {
|
||||||
|
sb.WriteString(f.String())
|
||||||
|
} else {
|
||||||
|
sb.WriteString(", ")
|
||||||
|
sb.WriteString(f.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
382
pgm_field_test.go
Normal file
382
pgm_field_test.go
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
package pgm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestValidateSQLIdentifier tests SQL identifier validation
|
||||||
|
func TestValidateSQLIdentifier(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"valid simple", "column_name", false},
|
||||||
|
{"valid with numbers", "column123", false},
|
||||||
|
{"valid underscore", "_private", false},
|
||||||
|
{"valid mixed", "my_Column_123", false},
|
||||||
|
{"invalid with space", "column name", true},
|
||||||
|
{"invalid with dash", "column-name", true},
|
||||||
|
{"invalid with dot", "table.column", true},
|
||||||
|
{"invalid with quote", "column'name", true},
|
||||||
|
{"invalid with semicolon", "column;DROP", true},
|
||||||
|
{"invalid starts with number", "123column", true},
|
||||||
|
{"invalid SQL keyword injection", "col); DROP TABLE users; --", true},
|
||||||
|
{"empty string", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateSQLIdentifier(tt.input)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validateSQLIdentifier(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEscapeSQLString tests SQL string escaping
|
||||||
|
func TestEscapeSQLString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"no quotes", "hello", "hello"},
|
||||||
|
{"single quote", "hello'world", "hello''world"},
|
||||||
|
{"multiple quotes", "it's a 'test'", "it''s a ''test''"},
|
||||||
|
{"only quote", "'", "''"},
|
||||||
|
{"empty string", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := escapeSQLString(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("escapeSQLString(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcatWsSQLInjectionPrevention tests that ConcatWs escapes quotes
|
||||||
|
func TestConcatWsSQLInjectionPrevention(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sep string
|
||||||
|
fields []Field
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "safe separator",
|
||||||
|
sep: ", ",
|
||||||
|
fields: []Field{"col1", "col2"},
|
||||||
|
contains: "concat_ws(', ',col1, col2)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "escaped quotes",
|
||||||
|
sep: "', (SELECT password FROM users), '",
|
||||||
|
fields: []Field{"col1"},
|
||||||
|
contains: "concat_ws(''', (SELECT password FROM users), ''',col1)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single quote",
|
||||||
|
sep: "'",
|
||||||
|
fields: []Field{"col1"},
|
||||||
|
contains: "concat_ws('''',col1)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ConcatWs(tt.sep, tt.fields...)
|
||||||
|
if !strings.Contains(string(result), tt.contains) {
|
||||||
|
t.Errorf("ConcatWs(%q, %v) = %q, should contain %q", tt.sep, tt.fields, result, tt.contains)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStringAggSQLInjectionPrevention tests that StringAgg escapes quotes
|
||||||
|
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
field Field
|
||||||
|
sep string
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "safe parameters",
|
||||||
|
field: Field("column_name"),
|
||||||
|
sep: ", ",
|
||||||
|
contains: "string_agg(column_name,', ')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "escaped quotes in separator",
|
||||||
|
sep: "'; DROP TABLE users; --",
|
||||||
|
field: Field("col"),
|
||||||
|
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := StringAgg(tt.field, tt.sep)
|
||||||
|
if !strings.Contains(string(result), tt.contains) {
|
||||||
|
t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
||||||
|
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
||||||
|
result := StringAggCast(Field("column"), "'; DROP TABLE")
|
||||||
|
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
||||||
|
if string(result) != expected {
|
||||||
|
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDateTruncValidation tests DateTrunc level validation
|
||||||
|
func TestDateTruncValidation(t *testing.T) {
|
||||||
|
field := Field("created_at")
|
||||||
|
|
||||||
|
validLevels := []string{
|
||||||
|
"microseconds", "milliseconds", "second", "minute", "hour",
|
||||||
|
"day", "week", "month", "quarter", "year", "decade", "century", "millennium",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, level := range validLevels {
|
||||||
|
t.Run("valid_"+level, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("DateTrunc(%q) should not panic for valid level", level)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
result := field.DateTrunc(level, "truncated")
|
||||||
|
if !strings.Contains(string(result), level) {
|
||||||
|
t.Errorf("DateTrunc result should contain level %q", level)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case-insensitive
|
||||||
|
t.Run("case_insensitive", func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("DateTrunc should accept uppercase level")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
result := field.DateTrunc("MONTH", "truncated")
|
||||||
|
if !strings.Contains(strings.ToLower(string(result)), "month") {
|
||||||
|
t.Errorf("DateTrunc should normalize case")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDateTruncInvalidLevel tests that DateTrunc panics on invalid level
|
||||||
|
func TestDateTruncInvalidLevel(t *testing.T) {
|
||||||
|
field := Field("created_at")
|
||||||
|
|
||||||
|
invalidLevels := []string{
|
||||||
|
"invalid",
|
||||||
|
"'; DROP TABLE users; --",
|
||||||
|
"day; DELETE FROM",
|
||||||
|
"",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, level := range invalidLevels {
|
||||||
|
t.Run("invalid_"+level, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("DateTrunc(%q, 'alias') should panic for invalid level", level)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.DateTrunc(level, "truncated")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDateTruncInvalidAlias tests that DateTrunc panics on invalid alias
|
||||||
|
func TestDateTruncInvalidAlias(t *testing.T) {
|
||||||
|
field := Field("created_at")
|
||||||
|
|
||||||
|
invalidAliases := []string{
|
||||||
|
"alias name",
|
||||||
|
"alias-name",
|
||||||
|
"'; DROP TABLE",
|
||||||
|
"123alias",
|
||||||
|
"alias.name",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, alias := range invalidAliases {
|
||||||
|
t.Run("invalid_alias_"+alias, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("DateTrunc('day', %q) should panic for invalid alias", alias)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.DateTrunc("day", alias)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTsRankValidation tests TsRank alias validation
|
||||||
|
func TestTsRankValidation(t *testing.T) {
|
||||||
|
field := Field("search_vector")
|
||||||
|
|
||||||
|
validAliases := []string{"rank", "score", "relevance_123", "_rank"}
|
||||||
|
for _, alias := range validAliases {
|
||||||
|
t.Run("valid_"+alias, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("TsRank should not panic for valid alias %q", alias)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
result := field.TsRank("query", alias)
|
||||||
|
if !strings.Contains(string(result), alias) {
|
||||||
|
t.Errorf("TsRank result should contain alias %q", alias)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
invalidAliases := []string{
|
||||||
|
"'; DROP TABLE",
|
||||||
|
"rank name",
|
||||||
|
"rank-name",
|
||||||
|
"123rank",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, alias := range invalidAliases {
|
||||||
|
t.Run("invalid_"+alias, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("TsRank should panic for invalid alias %q", alias)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.TsRank("query", alias)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRowNumberValidation tests ROW_NUMBER alias validation
|
||||||
|
func TestRowNumberValidation(t *testing.T) {
|
||||||
|
field := Field("id")
|
||||||
|
|
||||||
|
t.Run("valid_alias", func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("RowNumber should not panic for valid alias: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
result := field.RowNumber("row_num")
|
||||||
|
if !strings.Contains(string(result), "row_num") {
|
||||||
|
t.Errorf("RowNumber result should contain alias")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_alias", func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("RowNumber should panic for invalid alias")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.RowNumber("'; DROP TABLE")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSQLInjectionAttackVectors tests common SQL injection patterns
|
||||||
|
func TestSQLInjectionAttackVectors(t *testing.T) {
|
||||||
|
attacks := []string{
|
||||||
|
"'; DROP TABLE users; --",
|
||||||
|
"' OR '1'='1",
|
||||||
|
"'; DELETE FROM users WHERE '1'='1",
|
||||||
|
"1'; UPDATE users SET password='hacked'; --",
|
||||||
|
"admin'--",
|
||||||
|
"' UNION SELECT * FROM passwords--",
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("DateTrunc_alias_protection", func(t *testing.T) {
|
||||||
|
field := Field("created_at")
|
||||||
|
for _, attack := range attacks {
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("DateTrunc should prevent SQL injection: %q", attack)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.DateTrunc("day", attack)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TsRank_alias_protection", func(t *testing.T) {
|
||||||
|
field := Field("search_vector")
|
||||||
|
for _, attack := range attacks {
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("TsRank should prevent SQL injection: %q", attack)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.TsRank("query", attack)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ConcatWs_separator_escaping", func(t *testing.T) {
|
||||||
|
for _, attack := range attacks {
|
||||||
|
result := ConcatWs(attack, Field("col1"))
|
||||||
|
// Check that quotes are escaped (doubled)
|
||||||
|
if strings.Contains(attack, "'") && !strings.Contains(string(result), "''") {
|
||||||
|
t.Errorf("ConcatWs should escape quotes in attack: %q", attack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldName tests Field.Name() method
|
||||||
|
func TestFieldName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
field Field
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "field with table prefix",
|
||||||
|
field: Field("users.email"),
|
||||||
|
expected: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field with schema and table prefix",
|
||||||
|
field: Field("public.users.email"),
|
||||||
|
expected: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field without dot",
|
||||||
|
field: Field("email"),
|
||||||
|
expected: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field with multiple dots",
|
||||||
|
field: Field("schema.table.column"),
|
||||||
|
expected: "column",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "aggregate function",
|
||||||
|
field: Field("COUNT(users.id)"),
|
||||||
|
expected: "id)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.field.Name()
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
86
pgm_table.go
Normal file
86
pgm_table.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package pgm
|
||||||
|
|
||||||
|
// Table in database
|
||||||
|
type Table struct {
|
||||||
|
textSearch *textSearchCTE
|
||||||
|
|
||||||
|
Name string
|
||||||
|
DerivedTable Query
|
||||||
|
PK []string
|
||||||
|
FieldCount uint16
|
||||||
|
debug bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// text search Common Table Expression
|
||||||
|
type textSearchCTE struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
alias string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug when set true will print generated query string in stdout
|
||||||
|
func (t *Table) Debug() Clause {
|
||||||
|
t.debug = true
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Table) Field(f string) Field {
|
||||||
|
return Field(t.Name + "." + f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert table statement
|
||||||
|
func (t *Table) Insert() InsertClause {
|
||||||
|
qb := &insertQry{
|
||||||
|
debug: t.debug,
|
||||||
|
table: t.Name,
|
||||||
|
fields: make([]string, 0, t.FieldCount),
|
||||||
|
vals: make([]string, 0, t.FieldCount),
|
||||||
|
args: make([]any, 0, t.FieldCount),
|
||||||
|
}
|
||||||
|
return qb
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Table) WithTextSearch(name, alias, textToSearch string) *Table {
|
||||||
|
t.textSearch = &textSearchCTE{name: name, value: textToSearch, alias: alias}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select table statement
|
||||||
|
func (t *Table) Select(field ...Field) SelectClause {
|
||||||
|
qb := &selectQry{
|
||||||
|
debug: t.debug,
|
||||||
|
fields: field,
|
||||||
|
textSearch: t.textSearch,
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.DerivedTable != nil {
|
||||||
|
tName, args := t.DerivedTable.Build(true)
|
||||||
|
qb.table = "(" + tName + ") AS " + t.Name
|
||||||
|
qb.args = args
|
||||||
|
} else {
|
||||||
|
qb.table = t.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
return qb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update table statement
|
||||||
|
func (t *Table) Update() UpdateClause {
|
||||||
|
qb := &updateQry{
|
||||||
|
debug: t.debug,
|
||||||
|
table: t.Name,
|
||||||
|
cols: make([]string, 0, t.FieldCount),
|
||||||
|
args: make([]any, 0, t.FieldCount),
|
||||||
|
}
|
||||||
|
return qb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete table statement
|
||||||
|
func (t *Table) Delete() DeleteCluase {
|
||||||
|
qb := &deleteQry{
|
||||||
|
debug: t.debug,
|
||||||
|
table: t.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
return qb
|
||||||
|
}
|
||||||
670
pgm_test.go
Normal file
670
pgm_test.go
Normal file
@@ -0,0 +1,670 @@
|
|||||||
|
// Patial Tech.
|
||||||
|
// Author, Ankit Patial
|
||||||
|
|
||||||
|
package pgm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGetPoolNotInitialized verifies that GetPool panics when pool is not initialized
|
||||||
|
func TestGetPoolNotInitialized(t *testing.T) {
|
||||||
|
// Save current pool state
|
||||||
|
currentPool := poolPGX.Load()
|
||||||
|
|
||||||
|
// Clear the pool to simulate uninitialized state
|
||||||
|
poolPGX.Store(nil)
|
||||||
|
|
||||||
|
// Restore pool after test
|
||||||
|
defer func() {
|
||||||
|
poolPGX.Store(currentPool)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Verify panic occurs
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Error("GetPool should panic when pool not initialized")
|
||||||
|
} else {
|
||||||
|
// Check panic message
|
||||||
|
msg, ok := r.(string)
|
||||||
|
if !ok || !strings.Contains(msg, "not initialized") {
|
||||||
|
t.Errorf("Expected panic message about initialization, got: %v", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
GetPool() // Should panic
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConfigValidation tests that InitPool validates configuration
|
||||||
|
func TestConfigValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantPanic bool
|
||||||
|
panicMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty connection string",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "",
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "connection string is empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MinConns greater than MaxConns",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "postgres://localhost/test",
|
||||||
|
MaxConns: 5,
|
||||||
|
MinConns: 10,
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "MinConns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative MaxConns",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "postgres://localhost/test",
|
||||||
|
MaxConns: -1,
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "negative",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative MinConns",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "postgres://localhost/test",
|
||||||
|
MinConns: -1,
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "negative",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if tt.wantPanic && r == nil {
|
||||||
|
t.Errorf("InitPool should panic for %s", tt.name)
|
||||||
|
}
|
||||||
|
if !tt.wantPanic && r != nil {
|
||||||
|
t.Errorf("InitPool should not panic for %s, got: %v", tt.name, r)
|
||||||
|
}
|
||||||
|
if tt.wantPanic && r != nil {
|
||||||
|
msg := ""
|
||||||
|
switch v := r.(type) {
|
||||||
|
case string:
|
||||||
|
msg = v
|
||||||
|
case error:
|
||||||
|
msg = v.Error()
|
||||||
|
}
|
||||||
|
if !strings.Contains(msg, tt.panicMsg) {
|
||||||
|
t.Errorf("Expected panic message to contain %q, got: %q", tt.panicMsg, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
InitPool(tt.config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClosePoolIdempotent verifies ClosePool can be called multiple times safely
|
||||||
|
func TestClosePoolIdempotent(t *testing.T) {
|
||||||
|
// Save current pool
|
||||||
|
currentPool := poolPGX.Load()
|
||||||
|
defer func() {
|
||||||
|
poolPGX.Store(currentPool)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set to nil
|
||||||
|
poolPGX.Store(nil)
|
||||||
|
|
||||||
|
// Should not panic when called on nil pool
|
||||||
|
ClosePool()
|
||||||
|
ClosePool()
|
||||||
|
ClosePool()
|
||||||
|
|
||||||
|
// Should work fine - no panic expected
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsNotFound tests the error checking utility
|
||||||
|
func TestIsNotFound(t *testing.T) {
|
||||||
|
// Test with pgx.ErrNoRows
|
||||||
|
if !IsNotFound(pgx.ErrNoRows) {
|
||||||
|
t.Error("IsNotFound should return true for pgx.ErrNoRows")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with other errors
|
||||||
|
otherErr := errors.New("some other error")
|
||||||
|
if IsNotFound(otherErr) {
|
||||||
|
t.Error("IsNotFound should return false for non-ErrNoRows errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with nil
|
||||||
|
if IsNotFound(nil) {
|
||||||
|
t.Error("IsNotFound should return false for nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPgTime tests PostgreSQL timestamp conversion
|
||||||
|
func TestPgTime(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
pgTime := PgTime(now)
|
||||||
|
|
||||||
|
if !pgTime.Valid {
|
||||||
|
t.Error("PgTime should return valid timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pgTime.Time.Equal(now) {
|
||||||
|
t.Errorf("PgTime time mismatch: got %v, want %v", pgTime.Time, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPgTimeNow tests current time conversion
|
||||||
|
func TestPgTimeNow(t *testing.T) {
|
||||||
|
before := time.Now()
|
||||||
|
pgTime := PgTimeNow()
|
||||||
|
after := time.Now()
|
||||||
|
|
||||||
|
if !pgTime.Valid {
|
||||||
|
t.Error("PgTimeNow should return valid timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that time is between before and after
|
||||||
|
if pgTime.Time.Before(before) || pgTime.Time.After(after) {
|
||||||
|
t.Error("PgTimeNow should return current time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTsQueryFunctions tests full-text search query builders
|
||||||
|
func TestTsQueryFunctions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func(string) string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "TsAndQuery",
|
||||||
|
fn: TsAndQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello & world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsPrefixAndQuery",
|
||||||
|
fn: TsPrefixAndQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello:* & world:*",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsOrQuery",
|
||||||
|
fn: TsOrQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello | world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsPrefixOrQuery",
|
||||||
|
fn: TsPrefixOrQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello:* | world:*",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsAndQuery with multiple words",
|
||||||
|
fn: TsAndQuery,
|
||||||
|
input: "one two three",
|
||||||
|
expected: "one & two & three",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsOrQuery with multiple words",
|
||||||
|
fn: TsOrQuery,
|
||||||
|
input: "one two three",
|
||||||
|
expected: "one | two | three",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsAndQuery with extra spaces",
|
||||||
|
fn: TsAndQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello & world",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.fn(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("%s(%q) = %q, want %q", tt.name, tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldsWithSuffix tests the internal suffix function
|
||||||
|
func TestFieldsWithSuffix(t *testing.T) {
|
||||||
|
result := fieldsWithSufix("hello world", ":*")
|
||||||
|
expected := []string{"hello:*", "world:*"}
|
||||||
|
|
||||||
|
if len(result) != len(expected) {
|
||||||
|
t.Fatalf("fieldsWithSufix length mismatch: got %d, want %d", len(result), len(expected))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, v := range result {
|
||||||
|
if v != expected[i] {
|
||||||
|
t.Errorf("fieldsWithSufix[%d] = %q, want %q", i, v, expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBeginTxErrorWrapping tests that BeginTx preserves error context
|
||||||
|
func TestBeginTxErrorWrapping(t *testing.T) {
|
||||||
|
// This test verifies basic behavior without requiring a database connection
|
||||||
|
// Actual error wrapping would require a failing database connection
|
||||||
|
|
||||||
|
// Skip if no pool initialized (unit test environment)
|
||||||
|
if poolPGX.Load() == nil {
|
||||||
|
t.Skip("Skipping BeginTx test - requires initialized pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Just verify the function accepts a context
|
||||||
|
// In a real scenario with a cancelled context, it would return an error
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// We can't fully test without a real DB connection, so just document behavior
|
||||||
|
_, err := BeginTx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Error is expected in test environment without valid DB
|
||||||
|
t.Logf("BeginTx returned error (expected in test env): %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestContextCancellation tests that cancelled context is handled
|
||||||
|
func TestContextCancellation(t *testing.T) {
|
||||||
|
// Create an already-cancelled context
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
// Note: This test would require an actual database connection to verify
|
||||||
|
// that the context cancellation is properly propagated. We're testing
|
||||||
|
// that the context is accepted by the function signature.
|
||||||
|
|
||||||
|
// The actual behavior would be tested in integration tests
|
||||||
|
// For now, we just verify the function doesn't panic with cancelled context
|
||||||
|
|
||||||
|
// Skip if no pool initialized (unit test environment)
|
||||||
|
if poolPGX.Load() == nil {
|
||||||
|
t.Skip("Skipping context cancellation test - requires initialized pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// This would return an error about cancelled context in real scenario
|
||||||
|
_, err := BeginTx(ctx)
|
||||||
|
if err == nil {
|
||||||
|
t.Log("BeginTx with cancelled context - error expected in real scenario")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestErrorTypes verifies exported error variables
|
||||||
|
func TestErrorTypes(t *testing.T) {
|
||||||
|
if ErrConnStringMissing == nil {
|
||||||
|
t.Error("ErrConnStringMissing should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ErrInitTX == nil {
|
||||||
|
t.Error("ErrInitTX should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ErrCommitTX == nil {
|
||||||
|
t.Error("ErrCommitTX should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ErrNoRows == nil {
|
||||||
|
t.Error("ErrNoRows should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check error messages are descriptive
|
||||||
|
if !strings.Contains(ErrConnStringMissing.Error(), "connection string") {
|
||||||
|
t.Error("ErrConnStringMissing should mention connection string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStringPoolGetPut tests the string builder pool
|
||||||
|
func TestStringPoolGetPut(t *testing.T) {
|
||||||
|
// Get a string builder from pool
|
||||||
|
sb1 := getSB()
|
||||||
|
if sb1 == nil {
|
||||||
|
t.Fatal("getSB should return non-nil string builder")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use it
|
||||||
|
sb1.WriteString("test")
|
||||||
|
if sb1.String() != "test" {
|
||||||
|
t.Error("String builder should work normally")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put it back
|
||||||
|
putSB(sb1)
|
||||||
|
|
||||||
|
// Get another one - should be reset
|
||||||
|
sb2 := getSB()
|
||||||
|
if sb2 == nil {
|
||||||
|
t.Fatal("getSB should return non-nil string builder")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be empty after reset
|
||||||
|
if sb2.Len() != 0 {
|
||||||
|
t.Error("String builder from pool should be reset")
|
||||||
|
}
|
||||||
|
|
||||||
|
putSB(sb2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConditionActionTypes tests condition action enum
|
||||||
|
func TestConditionActionTypes(t *testing.T) {
|
||||||
|
// Verify enum values are distinct
|
||||||
|
actions := []CondAction{
|
||||||
|
CondActionNothing,
|
||||||
|
CondActionNeedToClose,
|
||||||
|
CondActionSubQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[CondAction]bool)
|
||||||
|
for _, action := range actions {
|
||||||
|
if seen[action] {
|
||||||
|
t.Errorf("Duplicate CondAction value: %v", action)
|
||||||
|
}
|
||||||
|
seen[action] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSelectQueryBuilderBasics tests basic query building
|
||||||
|
func TestSelectQueryBuilderBasics(t *testing.T) {
|
||||||
|
// Create a test table
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
emailField := Field("users.email")
|
||||||
|
|
||||||
|
// Test basic select
|
||||||
|
qry := testTable.Select(idField, emailField)
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "SELECT") {
|
||||||
|
t.Error("Query should contain SELECT")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "users.id") {
|
||||||
|
t.Error("Query should contain users.id field")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "users.email") {
|
||||||
|
t.Error("Query should contain users.email field")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "FROM users") {
|
||||||
|
t.Error("Query should contain FROM users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWhereConditionAccumulation tests that WHERE conditions accumulate
|
||||||
|
func TestWhereConditionAccumulation(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
statusField := Field("users.status")
|
||||||
|
|
||||||
|
// Create a query and add multiple WHERE conditions
|
||||||
|
qry := testTable.Select(idField).
|
||||||
|
Where(idField.Eq(1)).
|
||||||
|
Where(statusField.Eq("active"))
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
// Both conditions should be present
|
||||||
|
if !strings.Contains(sql, "WHERE") {
|
||||||
|
t.Error("Query should contain WHERE clause")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have multiple conditions (this demonstrates accumulation)
|
||||||
|
if strings.Count(sql, "$") < 2 {
|
||||||
|
t.Error("Query should have multiple parameterized values")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInsertQueryBuilder tests insert query building
|
||||||
|
func TestInsertQueryBuilder(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
|
||||||
|
qry := testTable.Insert()
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "INSERT INTO users") {
|
||||||
|
t.Error("Query should contain INSERT INTO users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpdateQueryBuilder tests update query building
|
||||||
|
func TestUpdateQueryBuilder(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
|
||||||
|
qry := testTable.Update().
|
||||||
|
Where(idField.Eq(1))
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "UPDATE users") {
|
||||||
|
t.Error("Query should contain UPDATE users")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "WHERE") {
|
||||||
|
t.Error("Query should contain WHERE clause")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteQueryBuilder tests delete query building
|
||||||
|
func TestDeleteQueryBuilder(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
|
||||||
|
qry := testTable.Delete().Where(idField.Eq(1))
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "DELETE FROM users") {
|
||||||
|
t.Error("Query should contain DELETE FROM users")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "WHERE") {
|
||||||
|
t.Error("Query should contain WHERE clause")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldConditioners tests various field condition methods
|
||||||
|
func TestFieldConditioners(t *testing.T) {
|
||||||
|
field := Field("users.age")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
condition Conditioner
|
||||||
|
checkFunc func(string) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Eq",
|
||||||
|
condition: field.Eq(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " = $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NotEq",
|
||||||
|
condition: field.NotEq(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " != $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gt",
|
||||||
|
condition: field.Gt(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " > $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Lt",
|
||||||
|
condition: field.Lt(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " < $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gte",
|
||||||
|
condition: field.Gte(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " >= $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Lte",
|
||||||
|
condition: field.Lte(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " <= $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IsNull",
|
||||||
|
condition: field.IsNull(),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " IS NULL") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IsNotNull",
|
||||||
|
condition: field.IsNotNull(),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " IS NOT NULL") },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Build a simple query to see the condition
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
qry := testTable.Select(field).Where(tt.condition)
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !tt.checkFunc(sql) {
|
||||||
|
t.Errorf("%s condition not properly rendered in SQL: %s", tt.name, sql)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldFunctions tests SQL function wrappers
|
||||||
|
func TestFieldFunctions(t *testing.T) {
|
||||||
|
field := Field("users.count")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result Field
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Count",
|
||||||
|
result: field.Count(),
|
||||||
|
contains: "COUNT(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Sum",
|
||||||
|
result: field.Sum(),
|
||||||
|
contains: "SUM(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Avg",
|
||||||
|
result: field.Avg(),
|
||||||
|
contains: "AVG(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max",
|
||||||
|
result: field.Max(),
|
||||||
|
contains: "MAX(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Min",
|
||||||
|
result: field.Min(),
|
||||||
|
contains: "Min(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Lower",
|
||||||
|
result: field.Lower(),
|
||||||
|
contains: "LOWER(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Upper",
|
||||||
|
result: field.Upper(),
|
||||||
|
contains: "UPPER(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Trim",
|
||||||
|
result: field.Trim(),
|
||||||
|
contains: "TRIM(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Asc",
|
||||||
|
result: field.Asc(),
|
||||||
|
contains: " ASC",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Desc",
|
||||||
|
result: field.Desc(),
|
||||||
|
contains: " DESC",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resultStr := string(tt.result)
|
||||||
|
if !strings.Contains(resultStr, tt.contains) {
|
||||||
|
t.Errorf("%s should contain %q, got: %s", tt.name, tt.contains, resultStr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInsertQueryValidation tests that Insert queries validate fields are set
|
||||||
|
func TestInsertQueryValidation(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
|
||||||
|
t.Run("Exec without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert()
|
||||||
|
err := qry.Exec(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.Exec() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExecTx without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert()
|
||||||
|
// We can't test ExecTx without a real transaction, but we can test the error is defined
|
||||||
|
err := qry.ExecTx(context.Background(), nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.ExecTx() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("First without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert().Returning(idField)
|
||||||
|
var id int
|
||||||
|
err := qry.First(context.Background(), &id)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.First() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("FirstTx without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert().Returning(idField)
|
||||||
|
var id int
|
||||||
|
err := qry.FirstTx(context.Background(), nil, &id)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.FirstTx() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package branchuser
|
package branchuser
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table branch_users
|
||||||
|
All pgm.Field = "branch_users.*"
|
||||||
// BranchID field has db type "bigint NOT NULL"
|
// BranchID field has db type "bigint NOT NULL"
|
||||||
BranchID pgm.Field = "branch_users.branch_id"
|
BranchID pgm.Field = "branch_users.branch_id"
|
||||||
// UserID field has db type "bigint NOT NULL"
|
// UserID field has db type "bigint NOT NULL"
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package comment
|
package comment
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table comments
|
||||||
|
All pgm.Field = "comments.*"
|
||||||
// ID field has db type "integer NOT NULL"
|
// ID field has db type "integer NOT NULL"
|
||||||
ID pgm.Field = "comments.id"
|
ID pgm.Field = "comments.id"
|
||||||
// PostID field has db type "integer NOT NULL"
|
// PostID field has db type "integer NOT NULL"
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package employee
|
package employee
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table employees
|
||||||
|
All pgm.Field = "employees.*"
|
||||||
// ID field has db type "integer NOT NULL"
|
// ID field has db type "integer NOT NULL"
|
||||||
ID pgm.Field = "employees.id"
|
ID pgm.Field = "employees.id"
|
||||||
// Name field has db type "var NOT NULL"
|
// Name field has db type "var NOT NULL"
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package post
|
package post
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table posts
|
||||||
|
All pgm.Field = "posts.*"
|
||||||
// ID field has db type "integer NOT NULL"
|
// ID field has db type "integer NOT NULL"
|
||||||
ID pgm.Field = "posts.id"
|
ID pgm.Field = "posts.id"
|
||||||
// UserID field has db type "integer NOT NULL"
|
// UserID field has db type "integer NOT NULL"
|
||||||
@@ -1,15 +1,22 @@
|
|||||||
// Code generated by code.patial.tech/go/pgm/cmd
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
// DO NOT EDIT.
|
|
||||||
|
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
User = pgm.Table{Name: "users", FieldCount: 11}
|
User = pgm.Table{Name: "users", FieldCount: 12}
|
||||||
UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8}
|
UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8}
|
||||||
BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5}
|
BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5}
|
||||||
Post = pgm.Table{Name: "posts", FieldCount: 5}
|
Post = pgm.Table{Name: "posts", FieldCount: 5}
|
||||||
Comment = pgm.Table{Name: "comments", FieldCount: 5}
|
Comment = pgm.Table{Name: "comments", FieldCount: 5}
|
||||||
Employee = pgm.Table{Name: "employees", FieldCount: 5}
|
Employee = pgm.Table{Name: "employees", FieldCount: 5}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||||
|
t := pgm.Table{
|
||||||
|
Name: tblName,
|
||||||
|
DerivedTable: fromQry,
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package user
|
package user
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table users
|
||||||
|
All pgm.Field = "users.*"
|
||||||
// ID field has db type "integer NOT NULL"
|
// ID field has db type "integer NOT NULL"
|
||||||
ID pgm.Field = "users.id"
|
ID pgm.Field = "users.id"
|
||||||
// Name field has db type "character varying(255) NOT NULL"
|
// Name field has db type "character varying(255) NOT NULL"
|
||||||
@@ -23,6 +25,8 @@ const (
|
|||||||
StatusID pgm.Field = "users.status_id"
|
StatusID pgm.Field = "users.status_id"
|
||||||
// MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying"
|
// MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying"
|
||||||
MfaKind pgm.Field = "users.mfa_kind"
|
MfaKind pgm.Field = "users.mfa_kind"
|
||||||
|
// SearchVector field has db type "tsvector"
|
||||||
|
SearchVector pgm.Field = "users.search_vector"
|
||||||
// CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
// CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||||
CreatedAt pgm.Field = "users.created_at"
|
CreatedAt pgm.Field = "users.created_at"
|
||||||
// UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
// UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package usersession
|
package usersession
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table user_sessions
|
||||||
|
All pgm.Field = "user_sessions.*"
|
||||||
// ID field has db type "character varying NOT NULL"
|
// ID field has db type "character varying NOT NULL"
|
||||||
ID pgm.Field = "user_sessions.id"
|
ID pgm.Field = "user_sessions.id"
|
||||||
// CreatedAt field has db type "timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL"
|
// CreatedAt field has db type "timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL"
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
//go:generate go run code.patial.tech/go/pgm/cmd -o ./db ./schema.sql
|
//go:generate go run code.patial.tech/go/pgm/cmd -o ./db ./schema.sql
|
||||||
|
|
||||||
package example
|
package playground
|
||||||
@@ -1,16 +1,16 @@
|
|||||||
package example
|
package playground
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"code.patial.tech/go/pgm/example/db"
|
"code.patial.tech/go/pgm/playground/db"
|
||||||
"code.patial.tech/go/pgm/example/db/user"
|
"code.patial.tech/go/pgm/playground/db/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDelete(t *testing.T) {
|
func TestDelete(t *testing.T) {
|
||||||
expected := "DELETE FROM users WHERE users.id = $1 AND users.status_id NOT IN($2)"
|
expected := "DELETE FROM users WHERE users.id = $1 AND users.status_id != ANY($2)"
|
||||||
got := db.User.Delete().
|
got := db.User.Delete().
|
||||||
Where(user.ID.Eq(1), user.StatusID.NotIn(1, 2, 3)).
|
Where(user.ID.Eq(1), user.StatusID.NotAny(1, 2, 3)).
|
||||||
String()
|
String()
|
||||||
if got != expected {
|
if got != expected {
|
||||||
t.Errorf("got %q, want %q", got, expected)
|
t.Errorf("got %q, want %q", got, expected)
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
package example
|
package playground
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"code.patial.tech/go/pgm"
|
"code.patial.tech/go/pgm"
|
||||||
"code.patial.tech/go/pgm/example/db"
|
"code.patial.tech/go/pgm/playground/db"
|
||||||
"code.patial.tech/go/pgm/example/db/user"
|
"code.patial.tech/go/pgm/playground/db/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestInsertQuery(t *testing.T) {
|
func TestInsertQuery(t *testing.T) {
|
||||||
@@ -17,24 +17,7 @@ func TestInsertQuery(t *testing.T) {
|
|||||||
Returning(user.ID).
|
Returning(user.ID).
|
||||||
String()
|
String()
|
||||||
|
|
||||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id"
|
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING users.id"
|
||||||
if got != expected {
|
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInsertSetMap(t *testing.T) {
|
|
||||||
got := db.User.Insert().
|
|
||||||
SetMap(map[pgm.Field]any{
|
|
||||||
user.Email: "aa@aa.com",
|
|
||||||
user.Phone: 8889991234,
|
|
||||||
user.FirstName: "fname",
|
|
||||||
user.LastName: "lname",
|
|
||||||
}).
|
|
||||||
Returning(user.ID).
|
|
||||||
String()
|
|
||||||
|
|
||||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id"
|
|
||||||
if got != expected {
|
if got != expected {
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||||
}
|
}
|
||||||
@@ -53,7 +36,7 @@ func TestInsertQuery2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkInsertQuery-12 1952412 605.3 ns/op 1144 B/op 18 allocs/op
|
// BenchmarkInsertQuery-12 2517757 459.6 ns/op 1032 B/op 13 allocs/op
|
||||||
func BenchmarkInsertQuery(b *testing.B) {
|
func BenchmarkInsertQuery(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = db.User.Insert().
|
_ = db.User.Insert().
|
||||||
@@ -67,7 +50,7 @@ func BenchmarkInsertQuery(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkInsertSetMap-12 1534039 777.1 ns/op 1480 B/op 20 allocs/op
|
// BenchmarkInsertSetMap-12 1812555 644.1 ns/op 1368 B/op 15 allocs/op
|
||||||
func BenchmarkInsertSetMap(b *testing.B) {
|
func BenchmarkInsertSetMap(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = db.User.Insert().
|
_ = db.User.Insert().
|
||||||
159
playground/qry_select_test.go
Normal file
159
playground/qry_select_test.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package playground
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"code.patial.tech/go/pgm"
|
||||||
|
"code.patial.tech/go/pgm/playground/db"
|
||||||
|
"code.patial.tech/go/pgm/playground/db/branchuser"
|
||||||
|
"code.patial.tech/go/pgm/playground/db/employee"
|
||||||
|
"code.patial.tech/go/pgm/playground/db/user"
|
||||||
|
"code.patial.tech/go/pgm/playground/db/usersession"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQryBuilder2(t *testing.T) {
|
||||||
|
got := db.User.Debug().Select(user.Email, user.FirstName).
|
||||||
|
Join(db.UserSession, user.ID, usersession.UserID).
|
||||||
|
Join(db.BranchUser, user.ID, branchuser.UserID).
|
||||||
|
Where(
|
||||||
|
user.ID.Eq(1),
|
||||||
|
pgm.Or(
|
||||||
|
user.StatusID.Eq(2),
|
||||||
|
user.UpdatedAt.Eq(3),
|
||||||
|
),
|
||||||
|
user.MfaKind.Eq(4),
|
||||||
|
pgm.Or(
|
||||||
|
user.FirstName.Eq(5),
|
||||||
|
user.MiddleName.Eq(6),
|
||||||
|
),
|
||||||
|
).
|
||||||
|
Where(
|
||||||
|
user.LastName.NotEq(7),
|
||||||
|
user.Phone.Like("%123%"),
|
||||||
|
user.UpdatedAt.IsNotNull(),
|
||||||
|
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
||||||
|
).
|
||||||
|
Limit(10).
|
||||||
|
Offset(100).
|
||||||
|
String()
|
||||||
|
|
||||||
|
expected := "SELECT users.email, users.first_name FROM users JOIN user_sessions ON users.id = user_sessions.user_id" +
|
||||||
|
" JOIN branch_users ON users.id = branch_users.user_id WHERE users.id = $1 AND (users.status_id = $2 OR users.updated_at = $3)" +
|
||||||
|
" AND users.mfa_kind = $4 AND (users.first_name = $5 OR users.middle_name = $6) AND users.last_name != $7 AND users.phone" +
|
||||||
|
" LIKE $8 AND users.updated_at IS NOT NULL AND users.email != ANY(SELECT users.id FROM users WHERE users.id = $9) LIMIT 10 OFFSET 100"
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWithHaving(t *testing.T) {
|
||||||
|
expected := "SELECT employees.department, AVG(employees.salary), COUNT(employees.id)" +
|
||||||
|
" FROM employees GROUP BY employees.department HAVING AVG(employees.salary) > $1 AND COUNT(employees.id) > $2"
|
||||||
|
got := db.Employee.
|
||||||
|
Select(employee.Department, employee.Salary.Avg(), employee.ID.Count()).
|
||||||
|
GroupBy(employee.Department).
|
||||||
|
Having(employee.Salary.Avg().Gt(50000), employee.ID.Count().Gt(5)).
|
||||||
|
String()
|
||||||
|
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWithJoin(t *testing.T) {
|
||||||
|
got := db.User.Select(user.Email, user.FirstName).
|
||||||
|
Join(db.UserSession, user.ID, usersession.UserID).
|
||||||
|
LeftJoin(db.BranchUser, user.ID, branchuser.UserID, pgm.Or(branchuser.RoleID.Eq("1"), branchuser.RoleID.Eq("2"))).
|
||||||
|
Where(
|
||||||
|
user.ID.Eq(3),
|
||||||
|
pgm.Or(
|
||||||
|
user.StatusID.Eq(4),
|
||||||
|
user.UpdatedAt.Eq(5),
|
||||||
|
),
|
||||||
|
).
|
||||||
|
Limit(10).
|
||||||
|
Offset(100).
|
||||||
|
String()
|
||||||
|
|
||||||
|
expected := "SELECT users.email, users.first_name " +
|
||||||
|
"FROM users JOIN user_sessions ON users.id = user_sessions.user_id " +
|
||||||
|
"LEFT JOIN branch_users ON users.id = branch_users.user_id AND (branch_users.role_id = $1 OR branch_users.role_id = $2) " +
|
||||||
|
"WHERE users.id = $3 AND (users.status_id = $4 OR users.updated_at = $5) " +
|
||||||
|
"LIMIT 10 OFFSET 100"
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectDerived(t *testing.T) {
|
||||||
|
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
|
||||||
|
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
|
||||||
|
" ORDER BY t.status_id, t.created_at DESC"
|
||||||
|
|
||||||
|
qry := db.User.
|
||||||
|
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
|
||||||
|
Where(user.StatusID.Eq(1))
|
||||||
|
|
||||||
|
tbl := db.DerivedTable("t", qry)
|
||||||
|
got := tbl.
|
||||||
|
Select(tbl.Field("*")).
|
||||||
|
Where(tbl.Field("rn").Lte(5)).
|
||||||
|
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
|
||||||
|
String()
|
||||||
|
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectTV(t *testing.T) {
|
||||||
|
expected := "WITH ts AS (SELECT to_tsquery('english', $1) AS query)" +
|
||||||
|
" SELECT users.first_name, users.last_name, users.email, TS_RANK(users.search_vector, ts.query) AS rank" +
|
||||||
|
" FROM users" +
|
||||||
|
" JOIN user_sessions ON users.id = user_sessions.user_id" +
|
||||||
|
" CROSS JOIN ts" +
|
||||||
|
" WHERE users.status_id = $2 AND users.search_vector @@ ts.query" +
|
||||||
|
" ORDER BY rank DESC"
|
||||||
|
|
||||||
|
qry := db.User.
|
||||||
|
WithTextSearch("ts", "query", "text to search").
|
||||||
|
Select(user.FirstName, user.LastName, user.Email, user.SearchVector.TsRank("ts.query", "rank")).
|
||||||
|
Join(db.UserSession, user.ID, usersession.UserID).
|
||||||
|
Where(user.StatusID.Eq(1), user.SearchVector.TsQuery("ts.query")).
|
||||||
|
OrderBy(pgm.Field("rank").Desc())
|
||||||
|
|
||||||
|
got := qry.String()
|
||||||
|
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
|
||||||
|
func BenchmarkSelect(b *testing.B) {
|
||||||
|
for b.Loop() {
|
||||||
|
_ = db.User.Select(user.Email, user.FirstName).
|
||||||
|
Join(db.UserSession, user.ID, usersession.UserID).
|
||||||
|
Join(db.BranchUser, user.ID, branchuser.UserID).
|
||||||
|
Where(
|
||||||
|
user.ID.Eq(1),
|
||||||
|
pgm.Or(
|
||||||
|
user.StatusID.Eq(2),
|
||||||
|
user.UpdatedAt.Eq(3),
|
||||||
|
),
|
||||||
|
user.MfaKind.Eq(4),
|
||||||
|
pgm.Or(
|
||||||
|
user.FirstName.Eq(5),
|
||||||
|
user.MiddleName.Eq(6),
|
||||||
|
),
|
||||||
|
).
|
||||||
|
Where(
|
||||||
|
user.LastName.NotEq(7),
|
||||||
|
user.Phone.Like("%123%"),
|
||||||
|
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
||||||
|
).
|
||||||
|
Limit(10).
|
||||||
|
Offset(100).
|
||||||
|
String()
|
||||||
|
}
|
||||||
|
}
|
||||||
61
playground/qry_update_test.go
Normal file
61
playground/qry_update_test.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package playground
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"code.patial.tech/go/pgm/playground/db"
|
||||||
|
"code.patial.tech/go/pgm/playground/db/user"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateQuery(t *testing.T) {
|
||||||
|
got := db.User.Update().
|
||||||
|
Set(user.FirstName, "ankit").
|
||||||
|
Set(user.MiddleName, "singh").
|
||||||
|
Set(user.LastName, "patial").
|
||||||
|
Where(
|
||||||
|
user.Email.Eq("aa@aa.com"),
|
||||||
|
).
|
||||||
|
Where(
|
||||||
|
user.StatusID.NotEq(1),
|
||||||
|
).
|
||||||
|
String()
|
||||||
|
|
||||||
|
expected := "UPDATE users SET first_name=$1, middle_name=$2, last_name=$3 WHERE users.email = $4 AND users.status_id != $5"
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateQueryValidation(t *testing.T) {
|
||||||
|
// Test that UPDATE without Set() returns error
|
||||||
|
err := db.User.Update().
|
||||||
|
Where(user.Email.Eq("aa@aa.com")).
|
||||||
|
Exec(context.Background())
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when calling Exec() without Set(), got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "no columns to update") {
|
||||||
|
t.Errorf("Expected error message to contain 'no columns to update', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkUpdateQuery-12 2334889 503.6 ns/op 1112 B/op 17 allocs/op
|
||||||
|
func BenchmarkUpdateQuery(b *testing.B) {
|
||||||
|
for b.Loop() {
|
||||||
|
_ = db.User.Update().
|
||||||
|
Set(user.FirstName, "ankit").
|
||||||
|
Set(user.MiddleName, "singh").
|
||||||
|
Set(user.LastName, "patial").
|
||||||
|
Where(
|
||||||
|
user.Email.Eq("aa@aa.com"),
|
||||||
|
).
|
||||||
|
Where(
|
||||||
|
user.StatusID.NotEq(1),
|
||||||
|
).
|
||||||
|
String()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -61,6 +61,7 @@ CREATE TABLE public.users (
|
|||||||
last_name character varying(50) NOT NULL,
|
last_name character varying(50) NOT NULL,
|
||||||
status_id smallint,
|
status_id smallint,
|
||||||
mfa_kind character varying(50) DEFAULT 'None'::character varying,
|
mfa_kind character varying(50) DEFAULT 'None'::character varying,
|
||||||
|
search_vector tsvector,
|
||||||
created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP
|
updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
);
|
);
|
||||||
98
pool.go
98
pool.go
@@ -1,98 +0,0 @@
|
|||||||
// Patial Tech.
|
|
||||||
// Author, Ankit Patial
|
|
||||||
|
|
||||||
package pgm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"log/slog"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
poolPGX atomic.Pointer[pgxpool.Pool]
|
|
||||||
poolStringBuilder = sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return new(strings.Builder)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ErrInitTX = errors.New("failed to init db.tx")
|
|
||||||
ErrCommitTX = errors.New("failed to commit db.tx")
|
|
||||||
ErrNoRows = errors.New("no data found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
MaxConns int32
|
|
||||||
MinConns int32
|
|
||||||
MaxConnLifetime time.Duration
|
|
||||||
MaxConnIdleTime time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
func Init(connString string, conf *Config) {
|
|
||||||
cfg, err := pgxpool.ParseConfig(connString)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf != nil {
|
|
||||||
if conf.MaxConns > 0 {
|
|
||||||
cfg.MaxConns = conf.MaxConns // 100
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.MinConns > 0 {
|
|
||||||
cfg.MinConns = conf.MaxConns // 5
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.MaxConnLifetime > 0 {
|
|
||||||
cfg.MaxConnLifetime = conf.MaxConnLifetime // time.Minute * 10
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.MaxConnIdleTime > 0 {
|
|
||||||
cfg.MaxConnIdleTime = conf.MaxConnIdleTime // time.Minute * 5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p, err := pgxpool.NewWithConfig(context.Background(), cfg)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = p.Ping(context.Background()); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
poolPGX.Store(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPool() *pgxpool.Pool {
|
|
||||||
return poolPGX.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
// get string builder from pool
|
|
||||||
func getSB() *strings.Builder {
|
|
||||||
return poolStringBuilder.Get().(*strings.Builder)
|
|
||||||
}
|
|
||||||
|
|
||||||
// put string builder back to pool
|
|
||||||
func putSB(sb *strings.Builder) {
|
|
||||||
sb.Reset()
|
|
||||||
poolStringBuilder.Put(sb)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
|
||||||
tx, err := poolPGX.Load().Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error(err.Error())
|
|
||||||
return nil, errors.New("failed to open db tx")
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx, err
|
|
||||||
}
|
|
||||||
200
qry.go
200
qry.go
@@ -7,156 +7,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Clause interface {
|
Clause interface {
|
||||||
|
Insert() InsertClause
|
||||||
Select(fields ...Field) SelectClause
|
Select(fields ...Field) SelectClause
|
||||||
// Insert() InsertSet
|
Update() UpdateClause
|
||||||
// Update() UpdateSet
|
Delete() DeleteCluase
|
||||||
// Delete() WhereOrExec
|
|
||||||
}
|
|
||||||
|
|
||||||
SelectClause interface {
|
|
||||||
// Join and Inner Join are same
|
|
||||||
Join(m Table, t1Field, t2Field Field) SelectClause
|
|
||||||
LeftJoin(m Table, t1Field, t2Field Field) SelectClause
|
|
||||||
RightJoin(m Table, t1Field, t2Field Field) SelectClause
|
|
||||||
FullJoin(m Table, t1Field, t2Field Field) SelectClause
|
|
||||||
CrossJoin(m Table) SelectClause
|
|
||||||
WhereClause
|
|
||||||
OrderByClause
|
|
||||||
GroupByClause
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
raw(prefixArgs []any) (string, []any)
|
|
||||||
}
|
|
||||||
|
|
||||||
WhereClause interface {
|
|
||||||
Where(cond ...Conditioner) AfterWhere
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterWhere interface {
|
|
||||||
WhereClause
|
|
||||||
GroupByClause
|
|
||||||
OrderByClause
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
GroupByClause interface {
|
|
||||||
GroupBy(fields ...Field) AfterGroupBy
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterGroupBy interface {
|
|
||||||
HavinClause
|
|
||||||
OrderByClause
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
HavinClause interface {
|
|
||||||
Having(cond ...Conditioner) AfterHaving
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterHaving interface {
|
|
||||||
OrderByClause
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
OrderByClause interface {
|
|
||||||
OrderBy(fields ...Field) AfterOrderBy
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterOrderBy interface {
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
LimitClause interface {
|
|
||||||
Limit(v int) AfterLimit
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterLimit interface {
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
OffsetClause interface {
|
|
||||||
Offset(v int) AfterOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterOffset interface {
|
|
||||||
LimitClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
Conditioner interface {
|
|
||||||
Condition(args *[]any, idx int) string
|
|
||||||
}
|
|
||||||
|
|
||||||
Insert interface {
|
|
||||||
Set(field Field, val any) InsertClause
|
|
||||||
SetMap(fields map[Field]any) InsertClause
|
|
||||||
}
|
|
||||||
|
|
||||||
InsertClause interface {
|
|
||||||
Insert
|
|
||||||
Returning(field Field) First
|
|
||||||
OnConflict(fields ...Field) Do
|
|
||||||
Execute
|
|
||||||
Stringer
|
|
||||||
}
|
|
||||||
|
|
||||||
Do interface {
|
|
||||||
DoNothing() Execute
|
|
||||||
DoUpdate(fields ...Field) Execute
|
|
||||||
}
|
|
||||||
|
|
||||||
Update interface {
|
|
||||||
Set(field Field, val any) UpdateClause
|
|
||||||
SetMap(fields map[Field]any) UpdateClause
|
|
||||||
}
|
|
||||||
|
|
||||||
UpdateClause interface {
|
|
||||||
Update
|
|
||||||
Where(cond ...Conditioner) WhereOrExec
|
|
||||||
}
|
|
||||||
|
|
||||||
WhereOrExec interface {
|
|
||||||
Where(cond ...Conditioner) WhereOrExec
|
|
||||||
Execute
|
|
||||||
}
|
|
||||||
|
|
||||||
Query interface {
|
|
||||||
First
|
|
||||||
All
|
|
||||||
Stringer
|
|
||||||
}
|
|
||||||
|
|
||||||
First interface {
|
|
||||||
First(ctx context.Context, dest ...any) error
|
|
||||||
FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error
|
|
||||||
Stringer
|
|
||||||
}
|
|
||||||
|
|
||||||
All interface {
|
|
||||||
// Query rows
|
|
||||||
//
|
|
||||||
// don't forget to close() rows
|
|
||||||
All(ctx context.Context, rows RowsCb) error
|
|
||||||
// Query rows
|
|
||||||
//
|
|
||||||
// don't forget to close() rows
|
|
||||||
AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute interface {
|
Execute interface {
|
||||||
@@ -169,26 +30,26 @@ type (
|
|||||||
String() string
|
String() string
|
||||||
}
|
}
|
||||||
|
|
||||||
RowScanner interface {
|
Conditioner interface {
|
||||||
Scan(dest ...any) error
|
Condition(args *[]any, idx int) string
|
||||||
}
|
}
|
||||||
|
|
||||||
RowsCb func(row RowScanner) error
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func joinFileds(fields []Field) string {
|
var sbPool = sync.Pool{
|
||||||
sb := getSB()
|
New: func() any {
|
||||||
defer putSB(sb)
|
return new(strings.Builder)
|
||||||
for i, f := range fields {
|
},
|
||||||
if i == 0 {
|
}
|
||||||
sb.WriteString(f.String())
|
|
||||||
} else {
|
|
||||||
sb.WriteString(", ")
|
|
||||||
sb.WriteString(f.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.String()
|
// get string builder from pool
|
||||||
|
func getSB() *strings.Builder {
|
||||||
|
return sbPool.Get().(*strings.Builder)
|
||||||
|
}
|
||||||
|
|
||||||
|
// put string builder back to pool
|
||||||
|
func putSB(sb *strings.Builder) {
|
||||||
|
sb.Reset()
|
||||||
|
sbPool.Put(sb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func And(cond ...Conditioner) Conditioner {
|
func And(cond ...Conditioner) Conditioner {
|
||||||
@@ -208,12 +69,16 @@ func (cv *Cond) Condition(args *[]any, argIdx int) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. normal condition
|
// 2. normal condition
|
||||||
*args = append(*args, cv.Val)
|
|
||||||
var op string
|
var op string
|
||||||
if strings.HasSuffix(cv.op, "$") {
|
if cv.Val != nil {
|
||||||
op = cv.op + strconv.Itoa(argIdx+1)
|
*args = append(*args, cv.Val)
|
||||||
|
if strings.HasSuffix(cv.op, "$") {
|
||||||
|
op = cv.op + strconv.Itoa(argIdx+1)
|
||||||
|
} else {
|
||||||
|
op = strings.Replace(cv.op, "$", "$"+strconv.Itoa(argIdx+1), 1)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
op = strings.Replace(cv.op, "$", "$"+strconv.Itoa(argIdx+1), 1)
|
op = cv.op
|
||||||
}
|
}
|
||||||
|
|
||||||
if cv.action == CondActionNeedToClose {
|
if cv.action == CondActionNeedToClose {
|
||||||
@@ -227,13 +92,14 @@ func (c *CondGroup) Condition(args *[]any, argIdx int) string {
|
|||||||
defer putSB(sb)
|
defer putSB(sb)
|
||||||
|
|
||||||
sb.WriteString("(")
|
sb.WriteString("(")
|
||||||
|
currentIdx := argIdx
|
||||||
for i, cond := range c.cond {
|
for i, cond := range c.cond {
|
||||||
if i == 0 {
|
if i > 0 {
|
||||||
sb.WriteString(cond.Condition(args, argIdx+i))
|
|
||||||
} else {
|
|
||||||
sb.WriteString(c.op)
|
sb.WriteString(c.op)
|
||||||
sb.WriteString(cond.Condition(args, argIdx+i))
|
|
||||||
}
|
}
|
||||||
|
argsLenBefore := len(*args)
|
||||||
|
sb.WriteString(cond.Condition(args, currentIdx))
|
||||||
|
currentIdx += len(*args) - argsLenBefore
|
||||||
}
|
}
|
||||||
sb.WriteString(")")
|
sb.WriteString(")")
|
||||||
return sb.String()
|
return sb.String()
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
DeleteCluase interface {
|
||||||
|
WhereOrExec
|
||||||
|
}
|
||||||
|
|
||||||
deleteQry struct {
|
deleteQry struct {
|
||||||
table string
|
table string
|
||||||
condition []Conditioner
|
condition []Conditioner
|
||||||
@@ -15,14 +19,11 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t *Table) Delete() WhereOrExec {
|
// WARNING: DELETE without WHERE clause will delete ALL rows in the table.
|
||||||
qb := &deleteQry{
|
// Always use Where() to specify conditions unless you intentionally want to delete all rows.
|
||||||
table: t.Name,
|
// Example:
|
||||||
debug: t.debug,
|
// table.Delete().Where(field.Eq(value)).Exec(ctx) // Safe
|
||||||
}
|
// table.Delete().Exec(ctx) // Deletes ALL rows!
|
||||||
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *deleteQry) Where(cond ...Conditioner) WhereOrExec {
|
func (q *deleteQry) Where(cond ...Conditioner) WhereOrExec {
|
||||||
q.condition = append(q.condition, cond...)
|
q.condition = append(q.condition, cond...)
|
||||||
|
|||||||
@@ -5,36 +5,40 @@ package pgm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type insertQry struct {
|
type (
|
||||||
returing *string
|
InsertClause interface {
|
||||||
onConflict *string
|
Insert
|
||||||
|
Returning(field Field) First
|
||||||
table string
|
OnConflict(fields ...Field) Do
|
||||||
conflictAction string
|
Execute
|
||||||
|
Stringer
|
||||||
fields []string
|
|
||||||
vals []string
|
|
||||||
args []any
|
|
||||||
debug bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Table) Insert() Insert {
|
|
||||||
qb := &insertQry{
|
|
||||||
table: t.Name,
|
|
||||||
fields: make([]string, 0, t.FieldCount),
|
|
||||||
vals: make([]string, 0, t.FieldCount),
|
|
||||||
args: make([]any, 0, t.FieldCount),
|
|
||||||
debug: t.debug,
|
|
||||||
}
|
}
|
||||||
return qb
|
|
||||||
}
|
Insert interface {
|
||||||
|
Set(field Field, val any) InsertClause
|
||||||
|
SetMap(fields map[Field]any) InsertClause
|
||||||
|
}
|
||||||
|
|
||||||
|
insertQry struct {
|
||||||
|
returing *string
|
||||||
|
onConflict *string
|
||||||
|
|
||||||
|
table string
|
||||||
|
conflictAction string
|
||||||
|
|
||||||
|
fields []string
|
||||||
|
vals []string
|
||||||
|
args []any
|
||||||
|
debug bool
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
func (q *insertQry) Set(field Field, val any) InsertClause {
|
func (q *insertQry) Set(field Field, val any) InsertClause {
|
||||||
q.fields = append(q.fields, field.Name())
|
q.fields = append(q.fields, field.Name())
|
||||||
@@ -51,7 +55,7 @@ func (q *insertQry) SetMap(cols map[Field]any) InsertClause {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) Returning(field Field) First {
|
func (q *insertQry) Returning(field Field) First {
|
||||||
col := field.Name()
|
col := field.String()
|
||||||
q.returing = &col
|
q.returing = &col
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
@@ -80,21 +84,28 @@ func (q *insertQry) DoNothing() Execute {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) DoUpdate(fields ...Field) Execute {
|
func (q *insertQry) DoUpdate(fields ...Field) Execute {
|
||||||
var sb strings.Builder
|
sb := getSB()
|
||||||
|
defer putSB(sb)
|
||||||
|
|
||||||
|
sb.WriteString("DO UPDATE SET ")
|
||||||
for i, f := range fields {
|
for i, f := range fields {
|
||||||
col := f.Name()
|
col := f.Name()
|
||||||
if i == 0 {
|
if i > 0 {
|
||||||
fmt.Fprintf(&sb, "%s = EXCLUDED.%s", col, col)
|
sb.WriteString(", ")
|
||||||
} else {
|
|
||||||
fmt.Fprintf(&sb, ", %s = EXCLUDED.%s", col, col)
|
|
||||||
}
|
}
|
||||||
|
sb.WriteString(col)
|
||||||
|
sb.WriteString(" = EXCLUDED.")
|
||||||
|
sb.WriteString(col)
|
||||||
}
|
}
|
||||||
|
|
||||||
q.conflictAction = "DO UPDATE SET " + sb.String()
|
q.conflictAction = sb.String()
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) Exec(ctx context.Context) error {
|
func (q *insertQry) Exec(ctx context.Context) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before Exec()")
|
||||||
|
}
|
||||||
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -103,6 +114,9 @@ func (q *insertQry) Exec(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before ExecTx()")
|
||||||
|
}
|
||||||
_, err := tx.Exec(ctx, q.String(), q.args...)
|
_, err := tx.Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -112,10 +126,16 @@ func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) First(ctx context.Context, dest ...any) error {
|
func (q *insertQry) First(ctx context.Context, dest ...any) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before First()")
|
||||||
|
}
|
||||||
return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before FirstTx()")
|
||||||
|
}
|
||||||
return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
245
qry_select.go
245
qry_select.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -14,7 +15,128 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
SelectClause interface {
|
||||||
|
// Join and Inner Join are same
|
||||||
|
Join(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||||
|
LeftJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||||
|
RightJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||||
|
FullJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||||
|
CrossJoin(m Table) SelectClause
|
||||||
|
WhereClause
|
||||||
|
OrderByClause
|
||||||
|
GroupByClause
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
raw(prefixArgs []any) (string, []any)
|
||||||
|
}
|
||||||
|
|
||||||
|
WhereClause interface {
|
||||||
|
Where(cond ...Conditioner) AfterWhere
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterWhere interface {
|
||||||
|
WhereClause
|
||||||
|
GroupByClause
|
||||||
|
OrderByClause
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
GroupByClause interface {
|
||||||
|
GroupBy(fields ...Field) AfterGroupBy
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterGroupBy interface {
|
||||||
|
HavinClause
|
||||||
|
OrderByClause
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
HavinClause interface {
|
||||||
|
Having(cond ...Conditioner) AfterHaving
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterHaving interface {
|
||||||
|
OrderByClause
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
OrderByClause interface {
|
||||||
|
OrderBy(fields ...Field) AfterOrderBy
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterOrderBy interface {
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
LimitClause interface {
|
||||||
|
Limit(v int) AfterLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterLimit interface {
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
OffsetClause interface {
|
||||||
|
Offset(v int) AfterOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterOffset interface {
|
||||||
|
LimitClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
Do interface {
|
||||||
|
DoNothing() Execute
|
||||||
|
DoUpdate(fields ...Field) Execute
|
||||||
|
}
|
||||||
|
|
||||||
|
Query interface {
|
||||||
|
First
|
||||||
|
All
|
||||||
|
Stringer
|
||||||
|
Bulder
|
||||||
|
}
|
||||||
|
|
||||||
|
RowScanner interface {
|
||||||
|
Scan(dest ...any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
RowsCb func(row RowScanner) error
|
||||||
|
|
||||||
|
First interface {
|
||||||
|
First(ctx context.Context, dest ...any) error
|
||||||
|
FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error
|
||||||
|
Stringer
|
||||||
|
}
|
||||||
|
|
||||||
|
All interface {
|
||||||
|
// Query rows
|
||||||
|
//
|
||||||
|
// don't forget to close() rows
|
||||||
|
All(ctx context.Context, rows RowsCb) error
|
||||||
|
// Query rows
|
||||||
|
//
|
||||||
|
// don't forget to close() rows
|
||||||
|
AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error
|
||||||
|
}
|
||||||
|
|
||||||
|
Bulder interface {
|
||||||
|
Build(needArgs bool) (qry string, args []any)
|
||||||
|
}
|
||||||
|
|
||||||
selectQry struct {
|
selectQry struct {
|
||||||
|
textSearch *textSearchCTE
|
||||||
|
|
||||||
table string
|
table string
|
||||||
fields []Field
|
fields []Field
|
||||||
args []any
|
args []any
|
||||||
@@ -23,9 +145,11 @@ type (
|
|||||||
groupBy []Field
|
groupBy []Field
|
||||||
having []Conditioner
|
having []Conditioner
|
||||||
orderBy []Field
|
orderBy []Field
|
||||||
limit int
|
|
||||||
offset int
|
limit int
|
||||||
debug bool
|
offset int
|
||||||
|
|
||||||
|
debug bool
|
||||||
}
|
}
|
||||||
|
|
||||||
CondAction uint8
|
CondAction uint8
|
||||||
@@ -51,34 +175,47 @@ const (
|
|||||||
CondActionSubQuery
|
CondActionSubQuery
|
||||||
)
|
)
|
||||||
|
|
||||||
// Select clause
|
func (q *selectQry) Join(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
func (t Table) Select(field ...Field) SelectClause {
|
return q.buildJoin(t, "JOIN", t1Field, t2Field, cond...)
|
||||||
qb := &selectQry{
|
}
|
||||||
table: t.Name,
|
|
||||||
debug: t.debug,
|
func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
fields: field,
|
return q.buildJoin(t, "LEFT JOIN", t1Field, t2Field, cond...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
|
return q.buildJoin(t, "RIGHT JOIN", t1Field, t2Field, cond...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
|
return q.buildJoin(t, "FULL JOIN", t1Field, t2Field, cond...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
|
str := joinKW + " " + t.Name + " ON " + t1Field.String() + " = " + t2Field.String()
|
||||||
|
if len(cond) == 0 { // Join with no condition
|
||||||
|
q.join = append(q.join, str)
|
||||||
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
return qb
|
// Join has condition(s)
|
||||||
}
|
sb := getSB()
|
||||||
|
defer putSB(sb)
|
||||||
|
sb.Grow(len(str) * 2)
|
||||||
|
|
||||||
func (q *selectQry) Join(t Table, t1Field, t2Field Field) SelectClause {
|
sb.WriteString(str)
|
||||||
q.join = append(q.join, "JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
sb.WriteString(" AND ")
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field) SelectClause {
|
var argIdx int
|
||||||
q.join = append(q.join, "LEFT JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
for i, c := range cond {
|
||||||
return q
|
argIdx = len(q.args)
|
||||||
}
|
if i > 0 {
|
||||||
|
sb.WriteString(" AND ")
|
||||||
|
}
|
||||||
|
sb.WriteString(c.Condition(&q.args, argIdx))
|
||||||
|
}
|
||||||
|
|
||||||
func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field) SelectClause {
|
q.join = append(q.join, sb.String())
|
||||||
q.join = append(q.join, "RIGHT JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field) SelectClause {
|
|
||||||
q.join = append(q.join, "FULL JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,9 +264,13 @@ func (q *selectQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
|||||||
|
|
||||||
func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
||||||
rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...)
|
rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...)
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if err != nil {
|
||||||
return ErrNoRows
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return ErrNoRows
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@@ -138,13 +279,21 @@ func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for errors from iteration
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
||||||
rows, err := tx.Query(ctx, q.String(), q.args...)
|
rows, err := tx.Query(ctx, q.String(), q.args...)
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if err != nil {
|
||||||
return ErrNoRows
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return ErrNoRows
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
@@ -154,6 +303,11 @@ func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for errors from iteration
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,11 +317,28 @@ func (q *selectQry) raw(prefixArgs []any) (string, []any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *selectQry) String() string {
|
func (q *selectQry) String() string {
|
||||||
|
qry, _ := q.Build(false)
|
||||||
|
if q.debug {
|
||||||
|
fmt.Println("***")
|
||||||
|
fmt.Println(qry)
|
||||||
|
fmt.Printf("%+v\n", q.args)
|
||||||
|
fmt.Println("***")
|
||||||
|
}
|
||||||
|
return qry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *selectQry) Build(needArgs bool) (qry string, args []any) {
|
||||||
sb := getSB()
|
sb := getSB()
|
||||||
defer putSB(sb)
|
defer putSB(sb)
|
||||||
|
|
||||||
sb.Grow(q.averageLen())
|
sb.Grow(q.averageLen())
|
||||||
|
|
||||||
|
if q.textSearch != nil {
|
||||||
|
var ts = q.textSearch
|
||||||
|
q.args = slices.Insert(q.args, 0, any(ts.value))
|
||||||
|
sb.WriteString("WITH " + ts.name + " AS (SELECT to_tsquery('english', $1) AS " + ts.alias + ") ")
|
||||||
|
}
|
||||||
|
|
||||||
// SELECT
|
// SELECT
|
||||||
sb.WriteString("SELECT ")
|
sb.WriteString("SELECT ")
|
||||||
sb.WriteString(joinFileds(q.fields))
|
sb.WriteString(joinFileds(q.fields))
|
||||||
@@ -178,6 +349,11 @@ func (q *selectQry) String() string {
|
|||||||
sb.WriteString(" " + strings.Join(q.join, " "))
|
sb.WriteString(" " + strings.Join(q.join, " "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Search Query Cross join
|
||||||
|
if q.textSearch != nil {
|
||||||
|
sb.WriteString(" CROSS JOIN " + q.textSearch.name)
|
||||||
|
}
|
||||||
|
|
||||||
// WHERE
|
// WHERE
|
||||||
if len(q.where) > 0 {
|
if len(q.where) > 0 {
|
||||||
sb.WriteString(" WHERE ")
|
sb.WriteString(" WHERE ")
|
||||||
@@ -228,14 +404,19 @@ func (q *selectQry) String() string {
|
|||||||
sb.WriteString(strconv.Itoa(q.offset))
|
sb.WriteString(strconv.Itoa(q.offset))
|
||||||
}
|
}
|
||||||
|
|
||||||
qry := sb.String()
|
qry = sb.String()
|
||||||
if q.debug {
|
if q.debug {
|
||||||
fmt.Println("***")
|
fmt.Println("***")
|
||||||
fmt.Println(qry)
|
fmt.Println(qry)
|
||||||
fmt.Printf("%+v\n", q.args)
|
fmt.Printf("%+v\n", q.args)
|
||||||
fmt.Println("***")
|
fmt.Println("***")
|
||||||
}
|
}
|
||||||
return qry
|
|
||||||
|
if needArgs {
|
||||||
|
args = slices.Clone(q.args)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *selectQry) averageLen() int {
|
func (q *selectQry) averageLen() int {
|
||||||
|
|||||||
@@ -5,29 +5,37 @@ package pgm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type updateQry struct {
|
type (
|
||||||
table string
|
Update interface {
|
||||||
cols []string
|
Set(field Field, val any) UpdateClause
|
||||||
condition []Conditioner
|
SetMap(fields map[Field]any) UpdateClause
|
||||||
args []any
|
|
||||||
debug bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Table) Update() Update {
|
|
||||||
qb := &updateQry{
|
|
||||||
table: t.Name,
|
|
||||||
debug: t.debug,
|
|
||||||
cols: make([]string, 0, t.FieldCount),
|
|
||||||
args: make([]any, 0, t.FieldCount),
|
|
||||||
}
|
}
|
||||||
return qb
|
|
||||||
}
|
UpdateClause interface {
|
||||||
|
Update
|
||||||
|
Where(cond ...Conditioner) WhereOrExec
|
||||||
|
}
|
||||||
|
|
||||||
|
WhereOrExec interface {
|
||||||
|
Where(cond ...Conditioner) WhereOrExec
|
||||||
|
Execute
|
||||||
|
}
|
||||||
|
|
||||||
|
updateQry struct {
|
||||||
|
table string
|
||||||
|
cols []string
|
||||||
|
condition []Conditioner
|
||||||
|
args []any
|
||||||
|
debug bool
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
func (q *updateQry) Set(field Field, val any) UpdateClause {
|
func (q *updateQry) Set(field Field, val any) UpdateClause {
|
||||||
col := field.Name()
|
col := field.Name()
|
||||||
@@ -49,6 +57,9 @@ func (q *updateQry) Where(cond ...Conditioner) WhereOrExec {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *updateQry) Exec(ctx context.Context) error {
|
func (q *updateQry) Exec(ctx context.Context) error {
|
||||||
|
if len(q.cols) == 0 {
|
||||||
|
return errors.New("update query has no columns to update: call Set() before Exec()")
|
||||||
|
}
|
||||||
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -58,6 +69,9 @@ func (q *updateQry) Exec(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *updateQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
func (q *updateQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
if len(q.cols) == 0 {
|
||||||
|
return errors.New("update query has no columns to update: call Set() before ExecTx()")
|
||||||
|
}
|
||||||
_, err := tx.Exec(ctx, q.String(), q.args...)
|
_, err := tx.Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
Reference in New Issue
Block a user