Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8d8c22d781 | |||
| c2cf7ff088 | |||
| cc7e6b7b3f | |||
| 48f1d1952e | |||
| 1d9d9d9308 | |||
| 29cddb6389 | |||
| 551e2123bc | |||
| a2b984c342 | |||
| a795c0e8d6 | |||
| bb6a45732f | |||
| 2551e07b3e | |||
| 9837fb1e37 | |||
| 12d6fface6 | |||
| 325103e8ef | |||
| 6f5748d3d3 | |||
| b25f9367ed | |||
| 8750f3ad95 | |||
| ad1faf2056 | |||
| 525c64e678 | |||
| 5f0fdadb8b | |||
| 68263895f7 | |||
| ee6cb445ab | |||
| d07c25fe01 | |||
| 096480a3eb | |||
| 6c14441591 | |||
| d95eea6636 | |||
| 63b71692b5 | |||
| 36e4145365 | |||
| 2ec328059f | |||
| f700f3e891 | |||
| f5350292fc |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -25,4 +25,4 @@ go.work.sum
|
||||
# env file
|
||||
.env
|
||||
|
||||
example/local_*
|
||||
playground/local_*
|
||||
|
||||
314
CLAUDE.md
Normal file
314
CLAUDE.md
Normal file
@@ -0,0 +1,314 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**pgm** is a lightweight, type-safe PostgreSQL query builder for Go, built on top of jackc/pgx. It generates Go code from SQL schema files, enabling compile-time safety for database queries without the overhead of traditional ORMs.
|
||||
|
||||
**Core Philosophy:**
|
||||
- Schema defined in SQL (not Go)
|
||||
- Minimal code generation (only table/column definitions)
|
||||
- Users provide their own models (no forced abstractions)
|
||||
- Type-safe query building with fluent API
|
||||
- Zero reflection for maximum performance
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Building
|
||||
|
||||
```bash
|
||||
# Build CLI with version from git tags
|
||||
make build
|
||||
|
||||
# Build with specific version
|
||||
make build VERSION=v1.2.3
|
||||
|
||||
# Install to GOPATH/bin
|
||||
make install
|
||||
|
||||
# Check current version
|
||||
make version
|
||||
pgm -version
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
# Run all tests in playground
|
||||
make test
|
||||
|
||||
# Run benchmarks for SELECT queries
|
||||
make bench-select
|
||||
|
||||
# Run code generator on playground schema
|
||||
make run
|
||||
```
|
||||
|
||||
### Code Generation
|
||||
|
||||
```bash
|
||||
# Generate Go code from SQL schema
|
||||
pgm -o ./db ./schema.sql
|
||||
|
||||
# The generator expects a SQL schema file and outputs:
|
||||
# - One package per table in the output directory
|
||||
# - Each package contains table and column definitions
|
||||
# Example output: db/users/users.go, db/posts/posts.go
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
#### 1. Query Builder System (qry_*.go files)
|
||||
|
||||
The query builder uses a **mutable, stateful design** for conditional query building:
|
||||
|
||||
- **qry_select.go**: SELECT queries with joins, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET
|
||||
- **qry_insert.go**: INSERT queries with RETURNING and UPSERT (ON CONFLICT) support
|
||||
- **qry_update.go**: UPDATE queries with conditional WHERE clauses
|
||||
- **qry_delete.go**: DELETE queries with WHERE conditions
|
||||
|
||||
**Important:** Query builders accumulate state with each method call. They are designed for conditional building within a single query but should NOT be reused across multiple separate queries.
|
||||
|
||||
```go
|
||||
// ✅ CORRECT - Conditional building
|
||||
query := users.User.Select(users.ID, users.Email)
|
||||
if nameFilter != "" {
|
||||
query = query.Where(users.Name.Like("%" + nameFilter + "%"))
|
||||
}
|
||||
err := query.First(ctx, &id, &email)
|
||||
|
||||
// ❌ WRONG - Reusing builder across separate queries
|
||||
baseQuery := users.User.Select(users.ID)
|
||||
baseQuery.Where(users.ID.Eq(1)).First(ctx, &id1) // Adds ID=1
|
||||
baseQuery.Where(users.Status.Eq(2)).First(ctx, &id2) // Has BOTH conditions!
|
||||
```
|
||||
|
||||
Query builders are NOT thread-safe. Each goroutine must create its own query instance.
|
||||
|
||||
#### 2. Field System (pgm_field.go)
|
||||
|
||||
Fields are type-safe column references with:
|
||||
- Comparison operators: `Eq()`, `NotEq()`, `Gt()`, `Lt()`, `Gte()`, `Lte()`
|
||||
- Pattern matching: `Like()`, `ILike()`, `LikeFold()`, `EqFold()`
|
||||
- NULL checks: `IsNull()`, `IsNotNull()`
|
||||
- Array operations: `Any()`, `NotAny()`
|
||||
- Aggregate functions: `Count()`, `Sum()`, `Avg()`, `Min()`, `Max()`
|
||||
- String functions: `Lower()`, `Upper()`, `Trim()`, `StringEscape()`
|
||||
- Special functions: `ConcatWs()`, `StringAgg()`, `DateTrunc()`, `RowNumber()`
|
||||
|
||||
**Security:** Functions like `ConcatWs()`, `StringAgg()`, and `DateTrunc()` validate inputs and escape SQL strings. They use allowlists for parameters like date truncation levels.
|
||||
|
||||
#### 3. Connection Pool (pgm.go)
|
||||
|
||||
Global connection pool using pgxpool with atomic pointer for thread safety:
|
||||
|
||||
- `InitPool(Config)`: Initialize once at startup (panics if called multiple times or with invalid config)
|
||||
- `GetPool()`: Retrieve pool instance (panics if not initialized)
|
||||
- `ClosePool()`: Graceful shutdown
|
||||
- `BeginTx(ctx)`: Start transactions
|
||||
|
||||
The pool is stored in an `atomic.Pointer[pgxpool.Pool]` for lock-free concurrent access.
|
||||
|
||||
#### 4. Code Generator (cmd/)
|
||||
|
||||
Parses SQL schema files and generates Go code:
|
||||
|
||||
- **cmd/main.go**: CLI entry point with version flag support
|
||||
- **cmd/parse.go**: Regex-based SQL parser (has known limitations)
|
||||
- **cmd/generate.go**: Code generation with go/format for proper formatting
|
||||
- **cmd/version.go**: Version string generation from build-time ldflags
|
||||
|
||||
**Parser Limitations:**
|
||||
- No multi-line comments (`/* */`)
|
||||
- Limited support for complex data types (arrays, JSON, JSONB)
|
||||
- No advanced PostgreSQL features (PARTITION BY, INHERITS)
|
||||
- Some constraints (CHECK, EXCLUDE) not parsed
|
||||
|
||||
Generated files include a header comment with version and timestamp:
|
||||
```go
|
||||
// Code generated by code.patial.tech/go/pgm/cmd v1.2.3 on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
```
|
||||
|
||||
#### 5. Table System (pgm_table.go)
|
||||
|
||||
Minimal table metadata:
|
||||
```go
|
||||
type Table struct {
|
||||
Name string // Table name
|
||||
FieldCount int // Number of columns
|
||||
PK []string // Primary key columns
|
||||
DerivedTable Query // For subqueries/CTEs
|
||||
}
|
||||
```
|
||||
|
||||
Provides factory methods: `Select()`, `Insert()`, `Update()`, `Delete()`
|
||||
|
||||
### String Builder Pool
|
||||
|
||||
Uses `sync.Pool` for efficient string building (qry.go):
|
||||
```go
|
||||
var sbPool = sync.Pool{
|
||||
New: func() any { return new(strings.Builder) }
|
||||
}
|
||||
```
|
||||
|
||||
All query builders use `getSB()` / `putSB()` to reduce allocations.
|
||||
|
||||
### Error Handling
|
||||
|
||||
Custom errors in pgm.go:
|
||||
- `ErrConnStringMissing`: Connection string validation
|
||||
- `ErrInitTX`: Transaction initialization failure
|
||||
- `ErrCommitTX`: Transaction commit failure
|
||||
- `ErrNoRows`: Wrapper for pgx.ErrNoRows
|
||||
|
||||
Use `pgm.IsNotFound(err)` to check for no rows errors.
|
||||
|
||||
### Generated Code Structure
|
||||
|
||||
For a schema with `users` and `posts` tables:
|
||||
```
|
||||
db/
|
||||
├── schema.go # Table definitions and DerivedTable helper
|
||||
├── user/
|
||||
│ └── users.go # User table columns as constants
|
||||
├── post/
|
||||
│ └── posts.go # Post table columns as constants
|
||||
└── ...
|
||||
```
|
||||
|
||||
Each table file exports constants like:
|
||||
```go
|
||||
const (
|
||||
All pgm.Field = "users.*"
|
||||
ID pgm.Field = "users.id"
|
||||
Email pgm.Field = "users.email"
|
||||
// ... all columns
|
||||
)
|
||||
```
|
||||
|
||||
### Naming Conventions
|
||||
|
||||
- **Table pluralization**: `pluralToSingular()` in cmd/generate.go handles irregular plurals (people→person, children→child) and common patterns (-ies→-y, -ves→-fe, -es→-e, -s→"")
|
||||
- **Field naming**: Snake_case columns converted to PascalCase (first_name → FirstName)
|
||||
- **ID suffix**: Fields ending in `_id` become `ID` not `Id` (user_id → UserID)
|
||||
|
||||
## Key Implementation Details
|
||||
|
||||
### Query Building Strategy
|
||||
|
||||
1. **Pre-allocation**: Queries estimate final string length via `averageLen()` methods to reduce allocations
|
||||
2. **Conditional accumulation**: All clauses stored in slices, built on demand
|
||||
3. **Parameterized queries**: Uses PostgreSQL numbered parameters (`$1`, `$2`, etc.)
|
||||
4. **Builder methods return interfaces**: Enforces correct method call sequences at compile time
|
||||
|
||||
Example type progression:
|
||||
```go
|
||||
SelectClause → WhereClause → AfterWhere → OrderByClause → Query → First/All
|
||||
```
|
||||
|
||||
### Transaction Handling
|
||||
|
||||
Pattern used throughout:
|
||||
```go
|
||||
tx, err := pgm.BeginTx(ctx)
|
||||
if err != nil { return err }
|
||||
defer tx.Rollback(ctx) // Safe to call after commit
|
||||
|
||||
// ... operations using ExecTx, FirstTx, AllTx methods ...
|
||||
|
||||
return tx.Commit(ctx)
|
||||
```
|
||||
|
||||
All query execution methods have `Tx` variants that accept `pgx.Tx`.
|
||||
|
||||
### Full-Text Search Helpers
|
||||
|
||||
Helper functions for PostgreSQL tsvector queries:
|
||||
- `TsAndQuery()`: AND operator between terms
|
||||
- `TsPrefixAndQuery()`: AND with prefix matching `:*`
|
||||
- `TsOrQuery()`: OR operator
|
||||
- `TsPrefixOrQuery()`: OR with prefix matching
|
||||
|
||||
Used with tsvector columns via `Field.TsQuery()`.
|
||||
|
||||
## Testing
|
||||
|
||||
Tests are in the `playground/` directory:
|
||||
- **playground/schema.sql**: Test database schema
|
||||
- **playground/db/**: Generated code from schema
|
||||
- **playground/*_test.go**: Integration tests for SELECT, INSERT, UPDATE, DELETE
|
||||
- **playground/local_select_test.go**: Additional SELECT test cases
|
||||
|
||||
Tests require a running PostgreSQL instance. The playground uses the generated code to verify the query builder works correctly.
|
||||
|
||||
## Version Management
|
||||
|
||||
Version is injected at build time via ldflags:
|
||||
```bash
|
||||
go build -ldflags "-X main.version=v1.2.3" ./cmd
|
||||
```
|
||||
|
||||
The Makefile automatically extracts version from git tags:
|
||||
```makefile
|
||||
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
```
|
||||
|
||||
This version appears in:
|
||||
- CLI `pgm -version` output
|
||||
- Generated file headers
|
||||
- Version string formatting in cmd/version.go
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Connection Pool Initialization
|
||||
Always initialize once at startup and defer cleanup:
|
||||
```go
|
||||
func main() {
|
||||
pgm.InitPool(pgm.Config{
|
||||
ConnString: os.Getenv("DATABASE_URL"),
|
||||
MaxConns: 25,
|
||||
MinConns: 5,
|
||||
})
|
||||
defer pgm.ClosePool()
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
### Safe Query Execution
|
||||
Check for no rows using the helper:
|
||||
```go
|
||||
err := users.User.Select(users.Email).Where(users.ID.Eq(id)).First(ctx, &email)
|
||||
if pgm.IsNotFound(err) {
|
||||
// Handle not found case
|
||||
}
|
||||
if err != nil {
|
||||
// Handle other errors
|
||||
}
|
||||
```
|
||||
|
||||
### Validation Requirements
|
||||
- UPDATE requires at least one `Set()` call
|
||||
- INSERT requires at least one `Set()` call
|
||||
- DELETE without `Where()` deletes ALL rows (dangerous!)
|
||||
|
||||
All execution methods validate these requirements and return descriptive errors.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- **sync.Pool**: Reuses string builders across queries
|
||||
- **Pre-allocation**: Queries pre-calculate buffer sizes
|
||||
- **Zero reflection**: Direct field access, no runtime type inspection
|
||||
- **pgxpool**: Leverages jackc/pgx's efficient connection pooling
|
||||
- **Direct scanning**: Users scan into their own types, no intermediate mapping
|
||||
|
||||
## Security Notes
|
||||
|
||||
- All queries use parameterized statements (PostgreSQL numbered parameters)
|
||||
- Field validation in functions like `DateTrunc()` uses allowlists
|
||||
- SQL string escaping in `escapeSQLString()` for literal values
|
||||
- Identifier validation via regex in `validateSQLIdentifier()`
|
||||
- Connection pool configuration validates against negative values and invalid ranges
|
||||
25
Makefile
25
Makefile
@@ -1,4 +1,25 @@
|
||||
.PHONY: run bench-select test build install
|
||||
|
||||
# Version can be set via: make build VERSION=v1.2.3
|
||||
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
|
||||
run:
|
||||
go run ./cmd -o ./example/db ./example/schema.sql
|
||||
go run ./cmd -o ./playground/db ./playground/schema.sql
|
||||
|
||||
bench-select:
|
||||
go test ./example -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
||||
go test ./playground -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
||||
|
||||
test:
|
||||
go test ./playground
|
||||
|
||||
# Build with version information
|
||||
build:
|
||||
go build -ldflags "-X main.version=$(VERSION)" -o pgm ./cmd
|
||||
|
||||
# Install to GOPATH/bin with version
|
||||
install:
|
||||
go install -ldflags "-X main.version=$(VERSION)" ./cmd
|
||||
|
||||
# Show current version
|
||||
version:
|
||||
@echo "Version: $(VERSION)"
|
||||
|
||||
795
README.md
795
README.md
@@ -1,5 +1,794 @@
|
||||
# pgm (Postgres Mapper)
|
||||
# pgm - PostgreSQL Query Mapper
|
||||
|
||||
Simple query builder to work with Go:PG apps.
|
||||
[](https://pkg.go.dev/code.patial.tech/go/pgm)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
Will work along side with [dbmate](https://github.com/amacneil/dbmate), will consume schema.sql file created by dbmate
|
||||
A lightweight, type-safe PostgreSQL query builder for Go, built on top of [jackc/pgx](https://github.com/jackc/pgx). **pgm** generates Go code from your SQL schema, enabling you to write SQL queries with compile-time safety and autocompletion support.
|
||||
|
||||
## Features
|
||||
|
||||
- **Type-safe queries** - Column and table names are validated at compile time
|
||||
- **Zero reflection** - Fast performance with no runtime reflection overhead
|
||||
- **SQL schema-based** - Generate Go code directly from your SQL schema files
|
||||
- **Fluent API** - Intuitive query builder with method chaining
|
||||
- **Transaction support** - First-class support for pgx transactions
|
||||
- **Full-text search** - Built-in PostgreSQL full-text search helpers
|
||||
- **Connection pooling** - Leverages pgx connection pool for optimal performance
|
||||
- **Minimal code generation** - Only generates what you need, no bloat
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Why pgm?](#why-pgm)
|
||||
- [Installation](#installation)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Usage Examples](#usage-examples)
|
||||
- [SELECT Queries](#select-queries)
|
||||
- [INSERT Queries](#insert-queries)
|
||||
- [UPDATE Queries](#update-queries)
|
||||
- [DELETE Queries](#delete-queries)
|
||||
- [Joins](#joins)
|
||||
- [Transactions](#transactions)
|
||||
- [Full-Text Search](#full-text-search)
|
||||
- [CLI Tool](#cli-tool)
|
||||
- [API Documentation](#api-documentation)
|
||||
- [Contributing](#contributing)
|
||||
- [License](#license)
|
||||
|
||||
## Why pgm?
|
||||
|
||||
### The Problem with Existing ORMs
|
||||
|
||||
While Go has excellent ORMs like [ent](https://github.com/ent/ent) and [sqlc](https://github.com/sqlc-dev/sqlc), they come with tradeoffs:
|
||||
|
||||
**ent** - Feature-rich but heavy:
|
||||
- Generates extensive code for features you may never use
|
||||
- Significantly increases binary size
|
||||
- Complex schema definition in Go instead of SQL
|
||||
- Auto-migrations can obscure actual database schema
|
||||
|
||||
**sqlc** - Great tool, but:
|
||||
- Creates separate database models, forcing model mapping
|
||||
- Query results require their own generated types
|
||||
- Less flexibility in dynamic query building
|
||||
|
||||
### The pgm Approach
|
||||
|
||||
**pgm** takes a hybrid approach:
|
||||
|
||||
✅ **Schema as SQL** - Define your database schema in pure SQL, where it belongs
|
||||
✅ **Minimal generation** - Only generates table and column definitions
|
||||
✅ **Your models** - Use your own application models, no forced abstractions
|
||||
✅ **Type safety** - Catch schema changes at compile time
|
||||
✅ **SQL power** - Full control over your queries with a fluent API
|
||||
✅ **Migration-friendly** - Use mature tools like [dbmate](https://github.com/amacneil/dbmate) for migrations
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get code.patial.tech/go/pgm
|
||||
```
|
||||
|
||||
Install the CLI tool for schema code generation:
|
||||
|
||||
```bash
|
||||
go install code.patial.tech/go/pgm/cmd@latest
|
||||
```
|
||||
|
||||
### Building from Source
|
||||
|
||||
Build with automatic version detection (uses git tags):
|
||||
|
||||
```bash
|
||||
# Build with version from git tags
|
||||
make build
|
||||
|
||||
# Build with specific version
|
||||
make build VERSION=v1.2.3
|
||||
|
||||
# Install to GOPATH/bin
|
||||
make install
|
||||
|
||||
# Or build manually with version
|
||||
go build -ldflags "-X main.version=v1.2.3" -o pgm ./cmd
|
||||
```
|
||||
|
||||
Check the version:
|
||||
|
||||
```bash
|
||||
pgm -version
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Create Your Schema
|
||||
|
||||
Create a SQL schema file `schema.sql` or use the one created by [dbmate](https://github.com/amacneil/dbmate):
|
||||
|
||||
```sql
|
||||
CREATE TABLE users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
email VARCHAR(255) UNIQUE NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE posts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id),
|
||||
title VARCHAR(500) NOT NULL,
|
||||
content TEXT,
|
||||
published BOOLEAN DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
```
|
||||
|
||||
### 2. Generate Go Code
|
||||
|
||||
Run the pgm CLI tool:
|
||||
|
||||
```bash
|
||||
pgm -o ./db ./schema.sql
|
||||
```
|
||||
|
||||
This generates Go files for each table in `./db/`:
|
||||
- `db/users/users.go` - Table and column definitions for users
|
||||
- `db/posts/posts.go` - Table and column definitions for posts
|
||||
|
||||
### 3. Use in Your Code
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"code.patial.tech/go/pgm"
|
||||
"yourapp/db/users"
|
||||
"yourapp/db/posts"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Initialize connection pool
|
||||
pgm.InitPool(pgm.Config{
|
||||
ConnString: "postgres://user:pass@localhost:5432/dbname",
|
||||
MaxConns: 25,
|
||||
MinConns: 5,
|
||||
})
|
||||
defer pgm.ClosePool() // Ensure graceful shutdown
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Query a user
|
||||
var email string
|
||||
err := users.User.Select(users.Email).
|
||||
Where(users.ID.Eq("some-uuid")).
|
||||
First(ctx, &email)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
log.Printf("User email: %s", email)
|
||||
}
|
||||
```
|
||||
|
||||
## Important: Query Builder Lifecycle
|
||||
|
||||
### ✅ Conditional Building (CORRECT)
|
||||
|
||||
Query builders are **mutable by design** to support conditional query building:
|
||||
|
||||
```go
|
||||
// ✅ CORRECT - Conditional building pattern
|
||||
query := users.User.Select(users.ID, users.Email, users.Name)
|
||||
|
||||
// Add conditions based on filters
|
||||
if nameFilter != "" {
|
||||
query = query.Where(users.Name.Like("%" + nameFilter + "%"))
|
||||
}
|
||||
|
||||
if statusFilter > 0 {
|
||||
query = query.Where(users.Status.Eq(statusFilter))
|
||||
}
|
||||
|
||||
if sortByName {
|
||||
query = query.OrderBy(users.Name.Asc())
|
||||
}
|
||||
|
||||
// Execute the final query with all accumulated conditions
|
||||
err := query.First(ctx, &id, &email, &name)
|
||||
```
|
||||
|
||||
**This is the intended use!** The builder accumulates your conditions, which is powerful and flexible.
|
||||
|
||||
### ❌ Unintentional Reuse (INCORRECT)
|
||||
|
||||
Don't try to create a "base query" and reuse it for **multiple different queries**:
|
||||
|
||||
```go
|
||||
// ❌ WRONG - Trying to reuse for multiple separate queries
|
||||
baseQuery := users.User.Select(users.ID, users.Email)
|
||||
|
||||
// First query - adds ID condition
|
||||
baseQuery.Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||
|
||||
// Second query - ALSO has ID=1 from above PLUS Status=2!
|
||||
baseQuery.Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||
// This executes: WHERE users.id = 1 AND users.status = 2 (WRONG!)
|
||||
|
||||
// ✅ CORRECT - Each separate query gets its own builder
|
||||
users.User.Select(users.ID, users.Email).Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||
users.User.Select(users.ID, users.Email).Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||
```
|
||||
|
||||
**Why?** Query builders are mutable and accumulate state. Each method call modifies the builder, so reusing the same builder causes conditions to stack up.
|
||||
|
||||
### Thread Safety
|
||||
|
||||
⚠️ **Query builders are NOT thread-safe** and must not be shared across goroutines:
|
||||
|
||||
```go
|
||||
// ✅ CORRECT - Each goroutine creates its own query
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
var email string
|
||||
err := users.User.Select(users.Email).
|
||||
Where(users.ID.Eq(id)).
|
||||
First(ctx, &email)
|
||||
// Process result...
|
||||
}(i)
|
||||
}
|
||||
|
||||
// ❌ WRONG - Sharing query builder across goroutines
|
||||
baseQuery := users.User.Select(users.Email)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
var email string
|
||||
baseQuery.Where(users.ID.Eq(id)).First(ctx, &email)
|
||||
// RACE CONDITION! Multiple goroutines modifying shared state
|
||||
}(i)
|
||||
}
|
||||
```
|
||||
|
||||
**Thread-Safe Components:**
|
||||
- ✅ Connection Pool - Safe for concurrent use
|
||||
- ✅ Table objects - Safe to share
|
||||
- ❌ Query builders - Create new instance per goroutine
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### SELECT Queries
|
||||
|
||||
#### Basic Select
|
||||
|
||||
```go
|
||||
var user struct {
|
||||
ID string
|
||||
Email string
|
||||
Name string
|
||||
}
|
||||
|
||||
err := users.User.Select(users.ID, users.Email, users.Name).
|
||||
Where(users.Email.Eq("john@example.com")).
|
||||
First(ctx, &user.ID, &user.Email, &user.Name)
|
||||
```
|
||||
|
||||
#### Select with Multiple Conditions
|
||||
|
||||
```go
|
||||
err := users.User.Select(users.ID, users.Email).
|
||||
Where(
|
||||
users.Email.Like("john%"),
|
||||
users.CreatedAt.Gt(time.Now().AddDate(0, -1, 0)),
|
||||
).
|
||||
OrderBy(users.CreatedAt.Desc()).
|
||||
Limit(10).
|
||||
First(ctx, &user.ID, &user.Email)
|
||||
```
|
||||
|
||||
#### Select All with Callback
|
||||
|
||||
```go
|
||||
var userList []User
|
||||
|
||||
err := users.User.Select(users.ID, users.Email, users.Name).
|
||||
Where(users.Name.Like("J%")).
|
||||
OrderBy(users.Name.Asc()).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
var u User
|
||||
if err := row.Scan(&u.ID, &u.Email, &u.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
userList = append(userList, u)
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
#### Pagination
|
||||
|
||||
```go
|
||||
page := 2
|
||||
pageSize := 20
|
||||
|
||||
err := users.User.Select(users.ID, users.Email).
|
||||
OrderBy(users.CreatedAt.Desc()).
|
||||
Limit(pageSize).
|
||||
Offset((page - 1) * pageSize).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
// Process rows
|
||||
})
|
||||
```
|
||||
|
||||
#### Grouping and Having
|
||||
|
||||
```go
|
||||
err := posts.Post.Select(posts.UserID, pgm.Count(posts.ID)).
|
||||
GroupBy(posts.UserID).
|
||||
Having(pgm.Count(posts.ID).Gt(5)).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
var userID string
|
||||
var postCount int
|
||||
return row.Scan(&userID, &postCount)
|
||||
})
|
||||
```
|
||||
|
||||
### INSERT Queries
|
||||
|
||||
#### Simple Insert
|
||||
|
||||
```go
|
||||
err := users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe").
|
||||
Set(users.CreatedAt, pgm.PgTimeNow()).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Insert with Map
|
||||
|
||||
```go
|
||||
data := map[pgm.Field]any{
|
||||
users.Email: "jane@example.com",
|
||||
users.Name: "Jane Doe",
|
||||
users.CreatedAt: pgm.PgTimeNow(),
|
||||
}
|
||||
|
||||
err := users.User.Insert().
|
||||
SetMap(data).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Insert with RETURNING
|
||||
|
||||
```go
|
||||
var newID string
|
||||
|
||||
err := users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe").
|
||||
Returning(users.ID).
|
||||
First(ctx, &newID)
|
||||
```
|
||||
|
||||
#### Upsert (INSERT ... ON CONFLICT)
|
||||
|
||||
```go
|
||||
// Do nothing on conflict
|
||||
err := users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe").
|
||||
OnConflict(users.Email).
|
||||
DoNothing().
|
||||
Exec(ctx)
|
||||
|
||||
// Update on conflict
|
||||
err := users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe Updated").
|
||||
OnConflict(users.Email).
|
||||
DoUpdate(users.Name).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
### UPDATE Queries
|
||||
|
||||
#### Simple Update
|
||||
|
||||
```go
|
||||
err := users.User.Update().
|
||||
Set(users.Name, "John Smith").
|
||||
Where(users.ID.Eq("some-uuid")).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Update Multiple Fields
|
||||
|
||||
```go
|
||||
updates := map[pgm.Field]any{
|
||||
users.Name: "John Smith",
|
||||
users.Email: "john.smith@example.com",
|
||||
}
|
||||
|
||||
err := users.User.Update().
|
||||
SetMap(updates).
|
||||
Where(users.ID.Eq("some-uuid")).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Conditional Update
|
||||
|
||||
```go
|
||||
err := users.User.Update().
|
||||
Set(users.Name, "Updated Name").
|
||||
Where(
|
||||
users.Email.Like("%@example.com"),
|
||||
users.CreatedAt.Lt(time.Now().AddDate(-1, 0, 0)),
|
||||
).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
### DELETE Queries
|
||||
|
||||
#### Simple Delete
|
||||
|
||||
```go
|
||||
err := users.User.Delete().
|
||||
Where(users.ID.Eq("some-uuid")).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Conditional Delete
|
||||
|
||||
```go
|
||||
err := posts.Post.Delete().
|
||||
Where(
|
||||
posts.Published.Eq(false),
|
||||
posts.CreatedAt.Lt(time.Now().AddDate(0, 0, -30)),
|
||||
).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
### Joins
|
||||
|
||||
#### Inner Join
|
||||
|
||||
```go
|
||||
err := posts.Post.Select(posts.Title, users.Name).
|
||||
Join(users.User, posts.UserID, users.ID).
|
||||
Where(users.Email.Eq("john@example.com")).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
var title, userName string
|
||||
return row.Scan(&title, &userName)
|
||||
})
|
||||
```
|
||||
|
||||
#### Left Join
|
||||
|
||||
```go
|
||||
err := users.User.Select(users.Name, posts.Title).
|
||||
LeftJoin(posts.Post, users.ID, posts.UserID).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
var userName, postTitle string
|
||||
return row.Scan(&userName, &postTitle)
|
||||
})
|
||||
```
|
||||
|
||||
#### Join with Additional Conditions
|
||||
|
||||
```go
|
||||
err := posts.Post.Select(posts.Title, users.Name).
|
||||
Join(users.User, posts.UserID, users.ID, users.Email.Like("%@example.com")).
|
||||
Where(posts.Published.Eq(true)).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
// Process rows
|
||||
})
|
||||
```
|
||||
|
||||
### Transactions
|
||||
|
||||
#### Basic Transaction
|
||||
|
||||
```go
|
||||
tx, err := pgm.BeginTx(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Insert user
|
||||
var userID string
|
||||
err = users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe").
|
||||
Returning(users.ID).
|
||||
FirstTx(ctx, tx, &userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert post
|
||||
err = posts.Post.Insert().
|
||||
Set(posts.UserID, userID).
|
||||
Set(posts.Title, "My First Post").
|
||||
Set(posts.Content, "Hello, World!").
|
||||
ExecTx(ctx, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
### Full-Text Search
|
||||
|
||||
PostgreSQL full-text search helpers:
|
||||
|
||||
```go
|
||||
// Search with AND operator (all terms must match)
|
||||
searchQuery := pgm.TsAndQuery("golang database")
|
||||
// Result: "golang & database"
|
||||
|
||||
// Search with prefix matching
|
||||
searchQuery := pgm.TsPrefixAndQuery("gol data")
|
||||
// Result: "gol:* & data:*"
|
||||
|
||||
// Search with OR operator (any term matches)
|
||||
searchQuery := pgm.TsOrQuery("golang rust")
|
||||
// Result: "golang | rust"
|
||||
|
||||
// Prefix OR search
|
||||
searchQuery := pgm.TsPrefixOrQuery("go ru")
|
||||
// Result: "go:* | ru:*"
|
||||
|
||||
// Use in query (assuming you have a tsvector column)
|
||||
err := posts.Post.Select(posts.Title, posts.Content).
|
||||
Where(posts.SearchVector.Match(pgm.TsPrefixAndQuery(searchTerm))).
|
||||
OrderBy(posts.CreatedAt.Desc()).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
// Process results
|
||||
})
|
||||
```
|
||||
|
||||
## CLI Tool
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
pgm -o <output_directory> <schema.sql>
|
||||
```
|
||||
|
||||
### Options
|
||||
|
||||
```bash
|
||||
-o string Output directory path (required)
|
||||
-version Show version information
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
```bash
|
||||
# Generate from a single schema file
|
||||
pgm -o ./db ./schema.sql
|
||||
|
||||
# Generate from concatenated migrations
|
||||
cat migrations/*.sql > /tmp/schema.sql && pgm -o ./db /tmp/schema.sql
|
||||
|
||||
# Check version
|
||||
pgm -version
|
||||
```
|
||||
|
||||
### Known Limitations
|
||||
|
||||
The CLI tool uses a regex-based SQL parser with the following limitations:
|
||||
|
||||
- ❌ Multi-line comments `/* */` are not supported
|
||||
- ❌ Complex data types (arrays, JSON, JSONB) may not parse correctly
|
||||
- ❌ Quoted identifiers with special characters may fail
|
||||
- ❌ Advanced PostgreSQL features (PARTITION BY, INHERITS) not supported
|
||||
- ❌ Some constraints (CHECK, EXCLUDE) are not parsed
|
||||
|
||||
**Workarounds:**
|
||||
- Use simple CREATE TABLE statements
|
||||
- Avoid complex PostgreSQL-specific syntax in schema files
|
||||
- Split complex schemas into multiple simple statements
|
||||
- Remove comments before running the generator
|
||||
|
||||
For complex schemas, consider contributing a more robust parser or using a proper SQL parser library.
|
||||
|
||||
### Generated Code Structure
|
||||
|
||||
For a table named `users`, pgm generates:
|
||||
|
||||
```
|
||||
db/
|
||||
└── users/
|
||||
└── users.go
|
||||
```
|
||||
|
||||
The generated file contains:
|
||||
- Generated code header with version and timestamp
|
||||
- Table definition (`User`)
|
||||
- Column field definitions (`ID`, `Email`, `Name`, etc.)
|
||||
- Type-safe query builders (`Select()`, `Insert()`, `Update()`, `Delete()`)
|
||||
|
||||
**Example header:**
|
||||
```go
|
||||
// Code generated by code.patial.tech/go/pgm/cmd v1.2.3 on 2025-01-27 15:04:05 DO NOT EDIT.
|
||||
```
|
||||
|
||||
The version in generated files helps track which version of the CLI tool was used, making it easier to identify when regeneration is needed after upgrades.
|
||||
|
||||
## API Documentation
|
||||
|
||||
### Connection Pool
|
||||
|
||||
#### InitPool
|
||||
|
||||
Initialize the connection pool (must be called once at startup):
|
||||
|
||||
```go
|
||||
pgm.InitPool(pgm.Config{
|
||||
ConnString: "postgres://...",
|
||||
MaxConns: 25,
|
||||
MinConns: 5,
|
||||
MaxConnLifetime: time.Hour,
|
||||
MaxConnIdleTime: time.Minute * 30,
|
||||
})
|
||||
```
|
||||
|
||||
**Configuration Validation:**
|
||||
- MinConns cannot be greater than MaxConns
|
||||
- Connection counts cannot be negative
|
||||
- Connection string is required
|
||||
|
||||
#### ClosePool
|
||||
|
||||
Close the connection pool gracefully (call during application shutdown):
|
||||
|
||||
```go
|
||||
func main() {
|
||||
pgm.InitPool(pgm.Config{
|
||||
ConnString: "postgres://...",
|
||||
})
|
||||
defer pgm.ClosePool() // Ensures proper cleanup
|
||||
|
||||
// Your application code
|
||||
}
|
||||
```
|
||||
|
||||
#### GetPool
|
||||
|
||||
Get the underlying pgx pool:
|
||||
|
||||
```go
|
||||
pool := pgm.GetPool()
|
||||
```
|
||||
|
||||
### Query Conditions
|
||||
|
||||
Available condition methods on fields:
|
||||
|
||||
- `Eq(value)` - Equal to
|
||||
- `NotEq(value)` - Not equal to
|
||||
- `Gt(value)` - Greater than
|
||||
- `Gte(value)` - Greater than or equal to
|
||||
- `Lt(value)` - Less than
|
||||
- `Lte(value)` - Less than or equal to
|
||||
- `Like(pattern)` - LIKE pattern match
|
||||
- `ILike(pattern)` - Case-insensitive LIKE
|
||||
- `In(values...)` - IN list
|
||||
- `NotIn(values...)` - NOT IN list
|
||||
- `IsNull()` - IS NULL
|
||||
- `IsNotNull()` - IS NOT NULL
|
||||
- `Between(start, end)` - BETWEEN range
|
||||
|
||||
### Field Methods
|
||||
|
||||
- `Asc()` - Sort ascending
|
||||
- `Desc()` - Sort descending
|
||||
- `Name()` - Get column name
|
||||
- `String()` - Get fully qualified name (table.column)
|
||||
|
||||
### Utilities
|
||||
|
||||
- `pgm.PgTime(t time.Time)` - Convert Go time to PostgreSQL timestamptz
|
||||
- `pgm.PgTimeNow()` - Current time as PostgreSQL timestamptz
|
||||
- `pgm.IsNotFound(err)` - Check if error is "no rows found"
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use transactions for related operations** - Ensure data consistency
|
||||
2. **Define schema in SQL** - Use migration tools like dbmate for schema management
|
||||
3. **Regenerate after schema changes** - Run the CLI tool after any schema modifications
|
||||
4. **Use your own models** - Don't let the database dictate your domain models
|
||||
5. **Handle pgx.ErrNoRows** - Use `pgm.IsNotFound(err)` for cleaner error checking
|
||||
6. **Always use context with timeouts** - Prevent queries from running indefinitely
|
||||
7. **Validate UPDATE queries** - Ensure Set() is called before Exec()
|
||||
8. **Be careful with DELETE** - Always use Where() unless you want to delete all rows
|
||||
|
||||
## Important Safety Notes
|
||||
|
||||
### ⚠️ DELETE Operations
|
||||
|
||||
DELETE without WHERE clause will delete ALL rows in the table:
|
||||
|
||||
```go
|
||||
// ❌ DANGEROUS - Deletes ALL rows!
|
||||
users.User.Delete().Exec(ctx)
|
||||
|
||||
// ✅ SAFE - Deletes specific rows
|
||||
users.User.Delete().Where(users.ID.Eq("user-id")).Exec(ctx)
|
||||
```
|
||||
|
||||
### ⚠️ UPDATE Operations
|
||||
|
||||
UPDATE requires at least one Set() call:
|
||||
|
||||
```go
|
||||
// ❌ ERROR - No columns to update
|
||||
users.User.Update().Where(users.ID.Eq(1)).Exec(ctx)
|
||||
// Returns: "update query has no columns to update"
|
||||
|
||||
// ✅ CORRECT
|
||||
users.User.Update().
|
||||
Set(users.Name, "New Name").
|
||||
Where(users.ID.Eq(1)).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
### ⚠️ Query Timeouts
|
||||
|
||||
Always use context with timeout to prevent hanging queries:
|
||||
|
||||
```go
|
||||
// ✅ RECOMMENDED
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := users.User.Select(users.Email).
|
||||
Where(users.ID.Eq("some-id")).
|
||||
First(ctx, &email)
|
||||
```
|
||||
|
||||
### ⚠️ Connection String Security
|
||||
|
||||
Never log or expose database connection strings as they contain credentials. The library does not sanitize connection strings in error messages.
|
||||
|
||||
## Performance
|
||||
|
||||
**pgm** is designed for performance:
|
||||
|
||||
- Zero reflection overhead
|
||||
- Efficient string building with sync.Pool
|
||||
- Leverages pgx's high-performance connection pooling
|
||||
- Minimal allocations in query building
|
||||
- Direct scanning into your types
|
||||
|
||||
## Requirements
|
||||
|
||||
- Go 1.20 or higher
|
||||
- PostgreSQL 12 or higher
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
## Author
|
||||
|
||||
|
||||
**Ankit Patial** - [Patial Tech](https://code.patial.tech)
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
Built on top of the excellent [jackc/pgx](https://github.com/jackc/pgx) library.
|
||||
|
||||
---
|
||||
|
||||
**Made with ❤️ for the Go community**
|
||||
|
||||
@@ -14,6 +14,10 @@ import (
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// version can be set at build time using:
|
||||
// go build -ldflags "-X main.version=v1.2.3" ./cmd
|
||||
var version = "dev"
|
||||
|
||||
func generate(scheamPath, outDir string) error {
|
||||
// read schame.sql
|
||||
f, err := os.ReadFile(scheamPath)
|
||||
@@ -29,14 +33,16 @@ func generate(scheamPath, outDir string) error {
|
||||
|
||||
// Output dir, create if not exists.
|
||||
if _, err := os.Stat(outDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(outDir, 0740); err != nil {
|
||||
if err := os.MkdirAll(outDir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// schema.go will hold all tables info
|
||||
var sb strings.Builder
|
||||
sb.WriteString("// Code generated by code.patial.tech/go/pgm/cmd\n// DO NOT EDIT.\n\n")
|
||||
sb.WriteString(
|
||||
fmt.Sprintf("// Code generated by code.patial.tech/go/pgm/cmd %s DO NOT EDIT.\n\n", GetVersionString()),
|
||||
)
|
||||
sb.WriteString(fmt.Sprintf("package %s \n", filepath.Base(outDir)))
|
||||
sb.WriteString(`
|
||||
import "code.patial.tech/go/pgm"
|
||||
@@ -70,14 +76,24 @@ func generate(scheamPath, outDir string) error {
|
||||
sb.WriteString("}\n")
|
||||
}
|
||||
modalDir = strings.ToLower(name)
|
||||
os.Mkdir(filepath.Join(outDir, modalDir), 0740)
|
||||
os.Mkdir(filepath.Join(outDir, modalDir), 0755)
|
||||
|
||||
if err = writeColFile(t.Table, t.Columns, filepath.Join(outDir, modalDir), caser); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(")")
|
||||
|
||||
sb.WriteString(`
|
||||
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||
t := pgm.Table{
|
||||
Name: tblName,
|
||||
DerivedTable: fromQry,
|
||||
}
|
||||
return t
|
||||
}`)
|
||||
|
||||
// Format code before saving
|
||||
code, err := formatGoCode(sb.String())
|
||||
if err != nil {
|
||||
@@ -85,16 +101,20 @@ func generate(scheamPath, outDir string) error {
|
||||
}
|
||||
|
||||
// Save file to disk
|
||||
os.WriteFile(filepath.Join(outDir, "schema.go"), code, 0640)
|
||||
os.WriteFile(filepath.Join(outDir, "schema.go"), code, 0644)
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Caser) error {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("// Code generated by db-gen. DO NOT EDIT.\n\n")
|
||||
sb.WriteString(
|
||||
fmt.Sprintf("// Code generated by code.patial.tech/go/pgm/cmd %s DO NOT EDIT.\n\n", GetVersionString()),
|
||||
)
|
||||
sb.WriteString(fmt.Sprintf("package %s\n\n", filepath.Base(outDir)))
|
||||
sb.WriteString(fmt.Sprintf("import %q\n\n", "code.patial.tech/go/pgm"))
|
||||
sb.WriteString("const (")
|
||||
sb.WriteString("\n // All fields in table " + tblName)
|
||||
sb.WriteString(fmt.Sprintf("\n All pgm.Field = %q", tblName+".*"))
|
||||
var name string
|
||||
for _, c := range cols {
|
||||
name = strings.ReplaceAll(c.Name, "_", " ")
|
||||
@@ -117,7 +137,7 @@ func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Cas
|
||||
return err
|
||||
}
|
||||
// Save file to disk.
|
||||
return os.WriteFile(filepath.Join(outDir, tblName+".go"), code, 0640)
|
||||
return os.WriteFile(filepath.Join(outDir, tblName+".go"), code, 0644)
|
||||
}
|
||||
|
||||
// pluralToSingular converts plural table names to singular forms
|
||||
|
||||
21
cmd/main.go
21
cmd/main.go
@@ -9,29 +9,42 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
const usageTxt = `Please provide output director and input schema.
|
||||
var (
|
||||
showVersion bool
|
||||
)
|
||||
|
||||
const usageTxt = `Please provide output directory and input schema.
|
||||
Example:
|
||||
pgm/cmd -o ./db ./db/schema.sql
|
||||
pgm -o ./db ./schema.sql
|
||||
|
||||
`
|
||||
|
||||
func main() {
|
||||
var outDir string
|
||||
flag.StringVar(&outDir, "o", "", "-o as output directory path")
|
||||
flag.BoolVar(&showVersion, "version", false, "show version information")
|
||||
flag.Parse()
|
||||
|
||||
// Handle version flag
|
||||
if showVersion {
|
||||
fmt.Printf("pgm %s\n", GetVersionString())
|
||||
fmt.Println("PostgreSQL Query Mapper - Schema code generator")
|
||||
fmt.Println("https://code.patial.tech/go/pgm")
|
||||
return
|
||||
}
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Print(usageTxt)
|
||||
return
|
||||
}
|
||||
|
||||
if outDir == "" {
|
||||
println("missing, -o output directory path")
|
||||
fmt.Fprintln(os.Stderr, "Error: missing output directory path (-o flag required)")
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
if err := generate(os.Args[3], outDir); err != nil {
|
||||
println(err.Error())
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
15
cmd/parse.go
15
cmd/parse.go
@@ -1,5 +1,20 @@
|
||||
// Patial Tech.
|
||||
// Author, Ankit Patial
|
||||
//
|
||||
// SQL Parser Limitations:
|
||||
// This is a simple regex-based SQL parser with the following known limitations:
|
||||
// - No support for multi-line comments /* */
|
||||
// - May struggle with complex data types (e.g., arrays, JSON, JSONB)
|
||||
// - No handling of quoted identifiers with special characters
|
||||
// - Advanced features like PARTITION BY, INHERITS not supported
|
||||
// - Does not handle all PostgreSQL-specific syntax
|
||||
// - Constraints like CHECK, EXCLUDE not parsed
|
||||
//
|
||||
// For complex schemas, consider:
|
||||
// 1. Simplifying your schema for generation
|
||||
// 2. Using multiple simple CREATE TABLE statements
|
||||
// 3. Contributing a more robust parser implementation
|
||||
// 4. Using a proper SQL parser library
|
||||
|
||||
package main
|
||||
|
||||
|
||||
67
cmd/version.go
Normal file
67
cmd/version.go
Normal file
@@ -0,0 +1,67 @@
|
||||
// Patial Tech.
|
||||
// Author, Ankit Patial
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Version returns the version of the pgm CLI tool.
|
||||
// It tries to detect the version in the following order:
|
||||
// 1. Build-time ldflags (set via: go build -ldflags "-X main.version=v1.2.3")
|
||||
// 2. VCS information from build metadata (git tag/commit)
|
||||
// 3. Falls back to "dev" if no version information is available
|
||||
func Version() string {
|
||||
// If version was set at build time via ldflags
|
||||
if version != "" && version != "dev" {
|
||||
return version
|
||||
}
|
||||
|
||||
// Try to get version from build info (Go 1.18+)
|
||||
if info, ok := debug.ReadBuildInfo(); ok {
|
||||
// Check for version in main module
|
||||
if info.Main.Version != "" && info.Main.Version != "(devel)" {
|
||||
return info.Main.Version
|
||||
}
|
||||
|
||||
// Try to extract from VCS information
|
||||
var revision, modified string
|
||||
for _, setting := range info.Settings {
|
||||
switch setting.Key {
|
||||
case "vcs.revision":
|
||||
revision = setting.Value
|
||||
case "vcs.modified":
|
||||
modified = setting.Value
|
||||
}
|
||||
}
|
||||
|
||||
// If we have a git revision
|
||||
if revision != "" {
|
||||
// Shorten commit hash to 7 characters
|
||||
if len(revision) > 7 {
|
||||
revision = revision[:7]
|
||||
}
|
||||
|
||||
// Add -dirty suffix if modified
|
||||
if modified == "true" {
|
||||
return "dev-" + revision + "-dirty"
|
||||
}
|
||||
return "dev-" + revision
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to dev
|
||||
return "dev"
|
||||
}
|
||||
|
||||
// GetVersionString returns a formatted version string for display
|
||||
func GetVersionString() string {
|
||||
v := Version()
|
||||
|
||||
// Clean up version string for display
|
||||
v = strings.TrimPrefix(v, "v")
|
||||
|
||||
return "v" + v
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
package example
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.patial.tech/go/pgm"
|
||||
"code.patial.tech/go/pgm/example/db"
|
||||
"code.patial.tech/go/pgm/example/db/branchuser"
|
||||
"code.patial.tech/go/pgm/example/db/employee"
|
||||
"code.patial.tech/go/pgm/example/db/user"
|
||||
"code.patial.tech/go/pgm/example/db/usersession"
|
||||
)
|
||||
|
||||
func TestQryBuilder2(t *testing.T) {
|
||||
got := db.User.Debug().Select(user.Email, user.FirstName).
|
||||
Join(db.UserSession, user.ID, usersession.UserID).
|
||||
Join(db.BranchUser, user.ID, branchuser.UserID).
|
||||
Where(
|
||||
user.ID.Eq(1),
|
||||
pgm.Or(
|
||||
user.StatusID.Eq(2),
|
||||
user.UpdatedAt.Eq(3),
|
||||
),
|
||||
user.MfaKind.Eq(4),
|
||||
pgm.Or(
|
||||
user.FirstName.Eq(5),
|
||||
user.MiddleName.Eq(6),
|
||||
),
|
||||
).
|
||||
Where(
|
||||
user.LastName.NEq(7),
|
||||
user.Phone.Like("%123%"),
|
||||
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
||||
).
|
||||
Limit(10).
|
||||
Offset(100).
|
||||
String()
|
||||
|
||||
expected := "SELECT users.email, users.first_name FROM users JOIN user_sessions ON users.id = user_sessions.user_id" +
|
||||
" JOIN branch_users ON users.id = branch_users.user_id WHERE users.id = $1 AND (users.status_id = $2 OR users.updated_at = $3)" +
|
||||
" AND users.mfa_kind = $4 AND (users.first_name = $5 OR users.middle_name = $6) AND users.last_name != $7 AND users.phone" +
|
||||
" LIKE $8 AND users.email NOT IN(SELECT users.id FROM users WHERE users.id = $9) LIMIT 10 OFFSET 100"
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithHaving(t *testing.T) {
|
||||
expected := "SELECT employees.department, AVG(employees.salary), COUNT(employees.id)" +
|
||||
" FROM employees GROUP BY employees.department HAVING AVG(employees.salary) > $1 AND COUNT(employees.id) > $2"
|
||||
got := db.Employee.
|
||||
Select(employee.Department, employee.Salary.Avg(), employee.ID.Count()).
|
||||
GroupBy(employee.Department).
|
||||
Having(employee.Salary.Avg().Gt(50000), employee.ID.Count().Gt(5)).
|
||||
String()
|
||||
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSelect-12 668817 1753 ns/op 4442 B/op 59 allocs/op
|
||||
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
|
||||
func BenchmarkSelect(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_ = db.User.Select(user.Email, user.FirstName).
|
||||
Join(db.UserSession, user.ID, usersession.UserID).
|
||||
Join(db.BranchUser, user.ID, branchuser.UserID).
|
||||
Where(
|
||||
user.ID.Eq(1),
|
||||
pgm.Or(
|
||||
user.StatusID.Eq(2),
|
||||
user.UpdatedAt.Eq(3),
|
||||
),
|
||||
user.MfaKind.Eq(4),
|
||||
pgm.Or(
|
||||
user.FirstName.Eq(5),
|
||||
user.MiddleName.Eq(6),
|
||||
),
|
||||
).
|
||||
Where(
|
||||
user.LastName.NEq(7),
|
||||
user.Phone.Like("%123%"),
|
||||
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
||||
).
|
||||
Limit(10).
|
||||
Offset(100).
|
||||
String()
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package example
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.patial.tech/go/pgm"
|
||||
"code.patial.tech/go/pgm/example/db"
|
||||
"code.patial.tech/go/pgm/example/db/user"
|
||||
)
|
||||
|
||||
func TestUpdateQuery(t *testing.T) {
|
||||
got := db.User.Update().
|
||||
Set(user.FirstName, "ankit").
|
||||
Set(user.MiddleName, "singh").
|
||||
Set(user.LastName, "patial").
|
||||
Where(
|
||||
user.Email.Eq("aa@aa.com"),
|
||||
).
|
||||
Where(
|
||||
user.StatusID.NEq(1),
|
||||
).
|
||||
String()
|
||||
|
||||
expected := "UPDATE users SET first_name=$1, middle_name=$2, last_name=$3 WHERE users.email = $4 AND users.status_id != $5"
|
||||
if got != expected {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSetMap(t *testing.T) {
|
||||
got := db.User.Update().
|
||||
SetMap(map[pgm.Field]any{
|
||||
user.FirstName: "ankit",
|
||||
user.MiddleName: "singh",
|
||||
user.LastName: "patial",
|
||||
}).
|
||||
Where(
|
||||
user.Email.Eq("aa@aa.com"),
|
||||
).
|
||||
Where(
|
||||
user.StatusID.NEq(1),
|
||||
).
|
||||
String()
|
||||
|
||||
expected := "UPDATE users SET first_name=$1, middle_name=$2, last_name=$3 WHERE users.email = $4 AND users.status_id != $5"
|
||||
if got != expected {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkUpdateQuery-12 2004985 592.2 ns/op 1176 B/op 20 allocs/op
|
||||
func BenchmarkUpdateQuery(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_ = db.User.Update().
|
||||
Set(user.FirstName, "ankit").
|
||||
Set(user.MiddleName, "singh").
|
||||
Set(user.LastName, "patial").
|
||||
Where(
|
||||
user.Email.Eq("aa@aa.com"),
|
||||
).
|
||||
Where(
|
||||
user.StatusID.NEq(1),
|
||||
).
|
||||
String()
|
||||
}
|
||||
}
|
||||
13
go.mod
13
go.mod
@@ -3,19 +3,14 @@ module code.patial.tech/go/pgm
|
||||
go 1.24.5
|
||||
|
||||
require (
|
||||
github.com/jackc/pgx v3.6.2+incompatible
|
||||
golang.org/x/text v0.27.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
golang.org/x/text v0.31.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jackc/pgx/v5 v5.7.5
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
golang.org/x/crypto v0.40.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
)
|
||||
|
||||
19
go.sum
19
go.sum
@@ -1,25 +1,36 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o=
|
||||
github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I=
|
||||
github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs=
|
||||
github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
|
||||
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
255
pgm.go
255
pgm.go
@@ -1,134 +1,197 @@
|
||||
// Patial Tech.
|
||||
// Author, Ankit Patial
|
||||
// pgm
|
||||
//
|
||||
// A simple PG string query builder
|
||||
//
|
||||
// Author: Ankit Patial
|
||||
|
||||
package pgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// Table in database
|
||||
type Table struct {
|
||||
Name string
|
||||
PK []string
|
||||
FieldCount uint16
|
||||
debug bool
|
||||
}
|
||||
|
||||
// Debug when set true will print generated query string in stdout
|
||||
func (t Table) Debug() Clause {
|
||||
t.debug = true
|
||||
return t
|
||||
var (
|
||||
poolPGX atomic.Pointer[pgxpool.Pool]
|
||||
ErrConnStringMissing = errors.New("connection string is empty")
|
||||
)
|
||||
|
||||
// Common errors returned by pgm operations
|
||||
var (
|
||||
ErrInitTX = errors.New("failed to init db.tx")
|
||||
ErrCommitTX = errors.New("failed to commit db.tx")
|
||||
ErrNoRows = errors.New("no data found")
|
||||
)
|
||||
|
||||
// Config holds the configuration for initializing the connection pool.
|
||||
// All fields except ConnString are optional and will use pgx defaults if not set.
|
||||
type Config struct {
|
||||
ConnString string
|
||||
MaxConns int32
|
||||
MinConns int32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
}
|
||||
|
||||
// InitPool initializes the connection pool with the provided configuration.
|
||||
// It validates the configuration and panics if invalid.
|
||||
// This function should be called once at application startup.
|
||||
//
|
||||
// Field ==>
|
||||
// Example:
|
||||
//
|
||||
// pgm.InitPool(pgm.Config{
|
||||
// ConnString: "postgres://user:pass@localhost/dbname",
|
||||
// MaxConns: 100,
|
||||
// MinConns: 5,
|
||||
// })
|
||||
func InitPool(conf Config) {
|
||||
if conf.ConnString == "" {
|
||||
panic(ErrConnStringMissing)
|
||||
}
|
||||
|
||||
// Field related to a table
|
||||
type Field string
|
||||
// Validate configuration
|
||||
if conf.MaxConns > 0 && conf.MinConns > 0 && conf.MinConns > conf.MaxConns {
|
||||
panic(fmt.Errorf("MinConns (%d) cannot be greater than MaxConns (%d)", conf.MinConns, conf.MaxConns))
|
||||
}
|
||||
|
||||
func (f Field) Name() string {
|
||||
return strings.Split(string(f), ".")[1]
|
||||
if conf.MaxConns < 0 || conf.MinConns < 0 {
|
||||
panic(errors.New("connection pool configuration cannot have negative values"))
|
||||
}
|
||||
|
||||
cfg, err := pgxpool.ParseConfig(conf.ConnString)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if conf.MaxConns > 0 {
|
||||
cfg.MaxConns = conf.MaxConns // 100
|
||||
}
|
||||
|
||||
if conf.MinConns > 0 {
|
||||
cfg.MinConns = conf.MinConns // 5
|
||||
}
|
||||
|
||||
if conf.MaxConnLifetime > 0 {
|
||||
cfg.MaxConnLifetime = conf.MaxConnLifetime // time.Minute * 10
|
||||
}
|
||||
|
||||
if conf.MaxConnIdleTime > 0 {
|
||||
cfg.MaxConnIdleTime = conf.MaxConnIdleTime // time.Minute * 5
|
||||
}
|
||||
|
||||
p, err := pgxpool.NewWithConfig(context.Background(), cfg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err = p.Ping(context.Background()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
poolPGX.Store(p)
|
||||
}
|
||||
|
||||
func (f Field) String() string {
|
||||
return string(f)
|
||||
// GetPool returns the initialized connection pool instance.
|
||||
// It panics with a descriptive message if InitPool() has not been called.
|
||||
// This is a fail-fast approach to catch programming errors early.
|
||||
func GetPool() *pgxpool.Pool {
|
||||
p := poolPGX.Load()
|
||||
if p == nil {
|
||||
panic("pgm: connection pool not initialized, call InitPool() first")
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Count fn wrapping of field
|
||||
func (f Field) Count() Field {
|
||||
return Field("COUNT(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// Avg fn wrapping of field
|
||||
func (f Field) Avg() Field {
|
||||
return Field("AVG(" + f.String() + ")")
|
||||
}
|
||||
|
||||
func (f Field) Eq(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " = $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
// EqualFold will user LOWER() for comparision
|
||||
func (f Field) EqFold(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " = LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) NEq(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " != $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) Gt(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " > $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) Gte(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " >= $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) Like(val string) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " LIKE $", len: len(f.String()) + 5}
|
||||
}
|
||||
|
||||
func (f Field) LikeFold(val string) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " LIKE LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||
}
|
||||
|
||||
// ILIKE is case-insensitive
|
||||
func (f Field) ILike(val string) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " ILIKE $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) NotIn(val ...any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " NOT IN($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: qry, op: " NOT IN($)", action: CondActionSubQuery}
|
||||
// ClosePool closes the connection pool gracefully.
|
||||
// Should be called during application shutdown.
|
||||
func ClosePool() {
|
||||
if p := poolPGX.Load(); p != nil {
|
||||
p.Close()
|
||||
poolPGX.Store(nil)
|
||||
}
|
||||
}
|
||||
|
||||
// BeginTx begins a new database transaction from the connection pool.
|
||||
// Returns an error if the transaction cannot be started.
|
||||
// Remember to commit or rollback the transaction when done.
|
||||
//
|
||||
// Helper func ==>
|
||||
// Example:
|
||||
//
|
||||
// tx, err := pgm.BeginTx(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer tx.Rollback(ctx) // rollback on error
|
||||
//
|
||||
// // ... do work ...
|
||||
//
|
||||
// return tx.Commit(ctx)
|
||||
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
||||
tx, err := poolPGX.Load().Begin(ctx)
|
||||
if err != nil {
|
||||
slog.Error("failed to begin transaction", "error", err)
|
||||
return nil, fmt.Errorf("failed to open db tx: %w", err)
|
||||
}
|
||||
|
||||
// PgTime as in UTC
|
||||
func PgTime(t time.Time) pgtype.Timestamptz {
|
||||
return pgtype.Timestamptz{Time: t, Valid: true}
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
func PgTimeNow() pgtype.Timestamptz {
|
||||
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
||||
}
|
||||
|
||||
// IsNotFound error check
|
||||
// IsNotFound checks if an error is a "no rows" error from pgx.
|
||||
// Returns true if the error indicates no rows were found in a query result.
|
||||
func IsNotFound(err error) bool {
|
||||
return errors.Is(err, pgx.ErrNoRows)
|
||||
}
|
||||
|
||||
func ConcatWs(sep string, fields ...Field) string {
|
||||
return "concat_ws('" + sep + "'," + joinFileds(fields) + ")"
|
||||
// PgTime converts a Go time.Time to PostgreSQL timestamptz type.
|
||||
// The time is stored as-is (preserves timezone information).
|
||||
func PgTime(t time.Time) pgtype.Timestamptz {
|
||||
return pgtype.Timestamptz{Time: t, Valid: true}
|
||||
}
|
||||
|
||||
func StringAgg(exp, sep string) string {
|
||||
return "string_agg(" + exp + ",'" + sep + "')"
|
||||
// PgTimeNow returns the current time as PostgreSQL timestamptz type.
|
||||
func PgTimeNow() pgtype.Timestamptz {
|
||||
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
||||
}
|
||||
|
||||
func StringAggCast(exp, sep string) string {
|
||||
return "string_agg(cast(" + exp + " as varchar),'" + sep + "')"
|
||||
// TsAndQuery converts a text search query to use AND operator between terms.
|
||||
// Example: "hello world" becomes "hello & world"
|
||||
func TsAndQuery(q string) string {
|
||||
return strings.Join(strings.Fields(q), " & ")
|
||||
}
|
||||
|
||||
// TsPrefixAndQuery converts a text search query to use AND operator with prefix matching.
|
||||
// Example: "hello world" becomes "hello:* & world:*"
|
||||
func TsPrefixAndQuery(q string) string {
|
||||
return strings.Join(fieldsWithSufix(q, ":*"), " & ")
|
||||
}
|
||||
|
||||
// TsOrQuery converts a text search query to use OR operator between terms.
|
||||
// Example: "hello world" becomes "hello | world"
|
||||
func TsOrQuery(q string) string {
|
||||
return strings.Join(strings.Fields(q), " | ")
|
||||
}
|
||||
|
||||
// TsPrefixOrQuery converts a text search query to use OR operator with prefix matching.
|
||||
// Example: "hello world" becomes "hello:* | world:*"
|
||||
func TsPrefixOrQuery(q string) string {
|
||||
return strings.Join(fieldsWithSufix(q, ":*"), " | ")
|
||||
}
|
||||
|
||||
func fieldsWithSufix(v, sufix string) []string {
|
||||
fields := strings.Fields(v)
|
||||
prefixed := make([]string, len(fields))
|
||||
for i, f := range fields {
|
||||
prefixed[i] = f + sufix
|
||||
}
|
||||
|
||||
return prefixed
|
||||
}
|
||||
|
||||
340
pgm_field.go
Normal file
340
pgm_field.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package pgm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Field related to a table
|
||||
type Field string
|
||||
|
||||
var (
|
||||
// sqlIdentifierRegex validates SQL identifiers (alphanumeric and underscore only)
|
||||
sqlIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// validDateTruncLevels contains all allowed DATE_TRUNC precision levels
|
||||
validDateTruncLevels = map[string]bool{
|
||||
"microseconds": true,
|
||||
"milliseconds": true,
|
||||
"second": true,
|
||||
"minute": true,
|
||||
"hour": true,
|
||||
"day": true,
|
||||
"week": true,
|
||||
"month": true,
|
||||
"quarter": true,
|
||||
"year": true,
|
||||
"decade": true,
|
||||
"century": true,
|
||||
"millennium": true,
|
||||
}
|
||||
)
|
||||
|
||||
// validateSQLIdentifier checks if a string is a valid SQL identifier
|
||||
func validateSQLIdentifier(s string) error {
|
||||
if !sqlIdentifierRegex.MatchString(s) {
|
||||
return fmt.Errorf("invalid SQL identifier: %q (must be alphanumeric and underscore only)", s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// escapeSQLString escapes single quotes in a string for SQL
|
||||
func escapeSQLString(s string) string {
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
func (f Field) Name() string {
|
||||
s := string(f)
|
||||
idx := strings.LastIndexByte(s, '.')
|
||||
if idx == -1 {
|
||||
return s // Return as-is if no dot
|
||||
}
|
||||
return s[idx+1:] // Return part after last dot
|
||||
}
|
||||
|
||||
func (f Field) String() string {
|
||||
return string(f)
|
||||
}
|
||||
|
||||
// Count function wrapped field
|
||||
func (f Field) Count() Field {
|
||||
return Field("COUNT(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// ConcatWs creates a CONCAT_WS SQL function.
|
||||
// SECURITY: The sep parameter should only be a constant string, not user input.
|
||||
// Single quotes in sep will be escaped automatically.
|
||||
func ConcatWs(sep string, fields ...Field) Field {
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("concat_ws('" + escapedSep + "'," + joinFileds(fields) + ")")
|
||||
}
|
||||
|
||||
// StringAgg creates a STRING_AGG SQL function.
|
||||
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||
func StringAgg(field Field, sep string) Field {
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("string_agg(" + field.String() + ",'" + escapedSep + "')")
|
||||
}
|
||||
|
||||
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
|
||||
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||
func StringAggCast(field Field, sep string) Field {
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("string_agg(cast(" + field.String() + " as varchar),'" + escapedSep + "')")
|
||||
}
|
||||
|
||||
// StringEscape will wrap field with:
|
||||
//
|
||||
// COALESCE(field, ”)
|
||||
func (f Field) StringEscape() Field {
|
||||
return Field("COALESCE(" + f.String() + ", '')")
|
||||
}
|
||||
|
||||
// NumberEscape will wrap field with:
|
||||
//
|
||||
// COALESCE(field, 0)
|
||||
func (f Field) NumberEscape() Field {
|
||||
return Field("COALESCE(" + f.String() + ", 0)")
|
||||
}
|
||||
|
||||
// BooleanEscape will wrap field with:
|
||||
//
|
||||
// COALESCE(field, FALSE)
|
||||
func (f Field) BooleanEscape() Field {
|
||||
return Field("COALESCE(" + f.String() + ", FALSE)")
|
||||
}
|
||||
|
||||
// Avg function wrapped field
|
||||
func (f Field) Avg() Field {
|
||||
return Field("AVG(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// Sum function wrapped field
|
||||
func (f Field) Sum() Field {
|
||||
return Field("SUM(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// Max function wrapped field
|
||||
func (f Field) Max() Field {
|
||||
return Field("MAX(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// Min function wrapped field
|
||||
func (f Field) Min() Field {
|
||||
return Field("Min(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// Lower function wrapped field
|
||||
func (f Field) Lower() Field {
|
||||
return Field("LOWER(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// Upper function wrapped field
|
||||
func (f Field) Upper() Field {
|
||||
return Field("UPPER(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// Trim function wrapped field
|
||||
func (f Field) Trim() Field {
|
||||
return Field("TRIM(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// Asc suffixed field, supposed to be used with order by
|
||||
func (f Field) Asc() Field {
|
||||
return Field(f.String() + " ASC")
|
||||
}
|
||||
|
||||
// Desc suffixed field, supposed to be used with order by
|
||||
func (f Field) Desc() Field {
|
||||
return Field(f.String() + " DESC")
|
||||
}
|
||||
|
||||
func (f Field) RowNumber(as string, extraOrderBy ...Field) Field {
|
||||
return rowNumber(&f, nil, true, as, extraOrderBy...)
|
||||
}
|
||||
|
||||
func (f Field) RowNumberDesc(as string, extraOrderBy ...Field) Field {
|
||||
return rowNumber(&f, nil, false, as, extraOrderBy...)
|
||||
}
|
||||
|
||||
// RowNumberPartionBy in ascending order
|
||||
func (f Field) RowNumberPartionBy(partition Field, as string, extraOrderBy ...Field) Field {
|
||||
return rowNumber(&f, &partition, true, as, extraOrderBy...)
|
||||
}
|
||||
|
||||
func (f Field) RowNumberDescPartionBy(partition Field, as string, extraOrderBy ...Field) Field {
|
||||
return rowNumber(&f, &partition, false, as, extraOrderBy...)
|
||||
}
|
||||
|
||||
func rowNumber(f, partition *Field, isAsc bool, as string, extraOrderBy ...Field) Field {
|
||||
// Validate as parameter is a valid SQL identifier
|
||||
if as != "" {
|
||||
if err := validateSQLIdentifier(as); err != nil {
|
||||
panic(fmt.Sprintf("invalid AS alias in rowNumber: %v", err))
|
||||
}
|
||||
}
|
||||
var orderBy string
|
||||
if isAsc {
|
||||
orderBy = " ASC"
|
||||
} else {
|
||||
orderBy = " DESC"
|
||||
}
|
||||
|
||||
if as == "" {
|
||||
as = "row_number"
|
||||
}
|
||||
|
||||
col := f.String()
|
||||
|
||||
// Build ORDER BY clause with primary field and extra fields
|
||||
sb := getSB()
|
||||
defer putSB(sb)
|
||||
|
||||
sb.WriteString(col)
|
||||
sb.WriteString(orderBy)
|
||||
|
||||
// Add extra ORDER BY fields
|
||||
for _, extra := range extraOrderBy {
|
||||
sb.WriteString(", ")
|
||||
sb.WriteString(extra.String())
|
||||
}
|
||||
|
||||
orderByClause := sb.String()
|
||||
|
||||
if partition != nil {
|
||||
return Field("ROW_NUMBER() OVER (PARTITION BY " + partition.String() + " ORDER BY " + orderByClause + ") AS " + as)
|
||||
}
|
||||
|
||||
return Field("ROW_NUMBER() OVER (ORDER BY " + orderByClause + ") AS " + as)
|
||||
}
|
||||
|
||||
func (f Field) IsNull() Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, op: " IS NULL", len: len(col) + 8}
|
||||
}
|
||||
|
||||
func (f Field) IsNotNull() Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, op: " IS NOT NULL", len: len(col) + 12}
|
||||
}
|
||||
|
||||
// DateTrunc will truncate date or timestamp to specified level of precision
|
||||
//
|
||||
// Level values:
|
||||
// - microseconds, milliseconds, second, minute, hour
|
||||
// - day, week (Monday start), month, quarter, year
|
||||
// - decade, century, millennium
|
||||
func (f Field) DateTrunc(level, as string) Field {
|
||||
// Validate level parameter against allowed values
|
||||
if !validDateTruncLevels[strings.ToLower(level)] {
|
||||
panic(fmt.Sprintf("invalid DATE_TRUNC level: %q (allowed: microseconds, milliseconds, second, minute, hour, day, week, month, quarter, year, decade, century, millennium)", level))
|
||||
}
|
||||
|
||||
// Validate as parameter is a valid SQL identifier
|
||||
if err := validateSQLIdentifier(as); err != nil {
|
||||
panic(fmt.Sprintf("invalid AS alias in DateTrunc: %v", err))
|
||||
}
|
||||
|
||||
return Field("DATE_TRUNC('" + strings.ToLower(level) + "', " + f.String() + ") AS " + as)
|
||||
}
|
||||
|
||||
func (f Field) TsRank(fieldName, as string) Field {
|
||||
// Validate as parameter is a valid SQL identifier
|
||||
if err := validateSQLIdentifier(as); err != nil {
|
||||
panic(fmt.Sprintf("invalid AS alias in TsRank: %v", err))
|
||||
}
|
||||
|
||||
return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as)
|
||||
}
|
||||
|
||||
// EqualFold will use LOWER(column_name) = LOWER(val) for comparision
|
||||
func (f Field) EqFold(val string) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " = LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||
}
|
||||
|
||||
// Eq is equal
|
||||
func (f Field) Eq(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " = $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) NotEq(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " != $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) Gt(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " > $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) Lt(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " < $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) Gte(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " >= $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) Lte(val any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " <= $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) Like(val string) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " LIKE $", len: len(f.String()) + 5}
|
||||
}
|
||||
|
||||
func (f Field) LikeFold(val string) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " LIKE LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||
}
|
||||
|
||||
// ILIKE is case-insensitive
|
||||
func (f Field) ILike(val string) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " ILIKE $", len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) Any(val ...any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " = ANY($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||
}
|
||||
|
||||
func (f Field) NotAny(val ...any) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: val, op: " != ANY($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||
}
|
||||
|
||||
// NotInSubQuery using ANY
|
||||
func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery}
|
||||
}
|
||||
|
||||
func (f Field) TsQuery(as string) Conditioner {
|
||||
col := f.String()
|
||||
return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5}
|
||||
}
|
||||
|
||||
func joinFileds(fields []Field) string {
|
||||
sb := getSB()
|
||||
defer putSB(sb)
|
||||
for i, f := range fields {
|
||||
if i == 0 {
|
||||
sb.WriteString(f.String())
|
||||
} else {
|
||||
sb.WriteString(", ")
|
||||
sb.WriteString(f.String())
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
382
pgm_field_test.go
Normal file
382
pgm_field_test.go
Normal file
@@ -0,0 +1,382 @@
|
||||
package pgm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestValidateSQLIdentifier tests SQL identifier validation
|
||||
func TestValidateSQLIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid simple", "column_name", false},
|
||||
{"valid with numbers", "column123", false},
|
||||
{"valid underscore", "_private", false},
|
||||
{"valid mixed", "my_Column_123", false},
|
||||
{"invalid with space", "column name", true},
|
||||
{"invalid with dash", "column-name", true},
|
||||
{"invalid with dot", "table.column", true},
|
||||
{"invalid with quote", "column'name", true},
|
||||
{"invalid with semicolon", "column;DROP", true},
|
||||
{"invalid starts with number", "123column", true},
|
||||
{"invalid SQL keyword injection", "col); DROP TABLE users; --", true},
|
||||
{"empty string", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateSQLIdentifier(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateSQLIdentifier(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEscapeSQLString tests SQL string escaping
|
||||
func TestEscapeSQLString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"no quotes", "hello", "hello"},
|
||||
{"single quote", "hello'world", "hello''world"},
|
||||
{"multiple quotes", "it's a 'test'", "it''s a ''test''"},
|
||||
{"only quote", "'", "''"},
|
||||
{"empty string", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := escapeSQLString(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("escapeSQLString(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcatWsSQLInjectionPrevention tests that ConcatWs escapes quotes
|
||||
func TestConcatWsSQLInjectionPrevention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sep string
|
||||
fields []Field
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "safe separator",
|
||||
sep: ", ",
|
||||
fields: []Field{"col1", "col2"},
|
||||
contains: "concat_ws(', ',col1, col2)",
|
||||
},
|
||||
{
|
||||
name: "escaped quotes",
|
||||
sep: "', (SELECT password FROM users), '",
|
||||
fields: []Field{"col1"},
|
||||
contains: "concat_ws(''', (SELECT password FROM users), ''',col1)",
|
||||
},
|
||||
{
|
||||
name: "single quote",
|
||||
sep: "'",
|
||||
fields: []Field{"col1"},
|
||||
contains: "concat_ws('''',col1)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConcatWs(tt.sep, tt.fields...)
|
||||
if !strings.Contains(string(result), tt.contains) {
|
||||
t.Errorf("ConcatWs(%q, %v) = %q, should contain %q", tt.sep, tt.fields, result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStringAggSQLInjectionPrevention tests that StringAgg escapes quotes
|
||||
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field Field
|
||||
sep string
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "safe parameters",
|
||||
field: Field("column_name"),
|
||||
sep: ", ",
|
||||
contains: "string_agg(column_name,', ')",
|
||||
},
|
||||
{
|
||||
name: "escaped quotes in separator",
|
||||
sep: "'; DROP TABLE users; --",
|
||||
field: Field("col"),
|
||||
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := StringAgg(tt.field, tt.sep)
|
||||
if !strings.Contains(string(result), tt.contains) {
|
||||
t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
||||
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
||||
result := StringAggCast(Field("column"), "'; DROP TABLE")
|
||||
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
||||
if string(result) != expected {
|
||||
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDateTruncValidation tests DateTrunc level validation
|
||||
func TestDateTruncValidation(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
|
||||
validLevels := []string{
|
||||
"microseconds", "milliseconds", "second", "minute", "hour",
|
||||
"day", "week", "month", "quarter", "year", "decade", "century", "millennium",
|
||||
}
|
||||
|
||||
for _, level := range validLevels {
|
||||
t.Run("valid_"+level, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("DateTrunc(%q) should not panic for valid level", level)
|
||||
}
|
||||
}()
|
||||
result := field.DateTrunc(level, "truncated")
|
||||
if !strings.Contains(string(result), level) {
|
||||
t.Errorf("DateTrunc result should contain level %q", level)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test case-insensitive
|
||||
t.Run("case_insensitive", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("DateTrunc should accept uppercase level")
|
||||
}
|
||||
}()
|
||||
result := field.DateTrunc("MONTH", "truncated")
|
||||
if !strings.Contains(strings.ToLower(string(result)), "month") {
|
||||
t.Errorf("DateTrunc should normalize case")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDateTruncInvalidLevel tests that DateTrunc panics on invalid level
|
||||
func TestDateTruncInvalidLevel(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
|
||||
invalidLevels := []string{
|
||||
"invalid",
|
||||
"'; DROP TABLE users; --",
|
||||
"day; DELETE FROM",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, level := range invalidLevels {
|
||||
t.Run("invalid_"+level, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("DateTrunc(%q, 'alias') should panic for invalid level", level)
|
||||
}
|
||||
}()
|
||||
field.DateTrunc(level, "truncated")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDateTruncInvalidAlias tests that DateTrunc panics on invalid alias
|
||||
func TestDateTruncInvalidAlias(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
|
||||
invalidAliases := []string{
|
||||
"alias name",
|
||||
"alias-name",
|
||||
"'; DROP TABLE",
|
||||
"123alias",
|
||||
"alias.name",
|
||||
}
|
||||
|
||||
for _, alias := range invalidAliases {
|
||||
t.Run("invalid_alias_"+alias, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("DateTrunc('day', %q) should panic for invalid alias", alias)
|
||||
}
|
||||
}()
|
||||
field.DateTrunc("day", alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTsRankValidation tests TsRank alias validation
|
||||
func TestTsRankValidation(t *testing.T) {
|
||||
field := Field("search_vector")
|
||||
|
||||
validAliases := []string{"rank", "score", "relevance_123", "_rank"}
|
||||
for _, alias := range validAliases {
|
||||
t.Run("valid_"+alias, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("TsRank should not panic for valid alias %q", alias)
|
||||
}
|
||||
}()
|
||||
result := field.TsRank("query", alias)
|
||||
if !strings.Contains(string(result), alias) {
|
||||
t.Errorf("TsRank result should contain alias %q", alias)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
invalidAliases := []string{
|
||||
"'; DROP TABLE",
|
||||
"rank name",
|
||||
"rank-name",
|
||||
"123rank",
|
||||
}
|
||||
|
||||
for _, alias := range invalidAliases {
|
||||
t.Run("invalid_"+alias, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("TsRank should panic for invalid alias %q", alias)
|
||||
}
|
||||
}()
|
||||
field.TsRank("query", alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRowNumberValidation tests ROW_NUMBER alias validation
|
||||
func TestRowNumberValidation(t *testing.T) {
|
||||
field := Field("id")
|
||||
|
||||
t.Run("valid_alias", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("RowNumber should not panic for valid alias: %v", r)
|
||||
}
|
||||
}()
|
||||
result := field.RowNumber("row_num")
|
||||
if !strings.Contains(string(result), "row_num") {
|
||||
t.Errorf("RowNumber result should contain alias")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_alias", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("RowNumber should panic for invalid alias")
|
||||
}
|
||||
}()
|
||||
field.RowNumber("'; DROP TABLE")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSQLInjectionAttackVectors tests common SQL injection patterns
|
||||
func TestSQLInjectionAttackVectors(t *testing.T) {
|
||||
attacks := []string{
|
||||
"'; DROP TABLE users; --",
|
||||
"' OR '1'='1",
|
||||
"'; DELETE FROM users WHERE '1'='1",
|
||||
"1'; UPDATE users SET password='hacked'; --",
|
||||
"admin'--",
|
||||
"' UNION SELECT * FROM passwords--",
|
||||
}
|
||||
|
||||
t.Run("DateTrunc_alias_protection", func(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
for _, attack := range attacks {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("DateTrunc should prevent SQL injection: %q", attack)
|
||||
}
|
||||
}()
|
||||
field.DateTrunc("day", attack)
|
||||
}()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TsRank_alias_protection", func(t *testing.T) {
|
||||
field := Field("search_vector")
|
||||
for _, attack := range attacks {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("TsRank should prevent SQL injection: %q", attack)
|
||||
}
|
||||
}()
|
||||
field.TsRank("query", attack)
|
||||
}()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcatWs_separator_escaping", func(t *testing.T) {
|
||||
for _, attack := range attacks {
|
||||
result := ConcatWs(attack, Field("col1"))
|
||||
// Check that quotes are escaped (doubled)
|
||||
if strings.Contains(attack, "'") && !strings.Contains(string(result), "''") {
|
||||
t.Errorf("ConcatWs should escape quotes in attack: %q", attack)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFieldName tests Field.Name() method
|
||||
func TestFieldName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field Field
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "field with table prefix",
|
||||
field: Field("users.email"),
|
||||
expected: "email",
|
||||
},
|
||||
{
|
||||
name: "field with schema and table prefix",
|
||||
field: Field("public.users.email"),
|
||||
expected: "email",
|
||||
},
|
||||
{
|
||||
name: "field without dot",
|
||||
field: Field("email"),
|
||||
expected: "email",
|
||||
},
|
||||
{
|
||||
name: "field with multiple dots",
|
||||
field: Field("schema.table.column"),
|
||||
expected: "column",
|
||||
},
|
||||
{
|
||||
name: "aggregate function",
|
||||
field: Field("COUNT(users.id)"),
|
||||
expected: "id)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.field.Name()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
86
pgm_table.go
Normal file
86
pgm_table.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package pgm
|
||||
|
||||
// Table in database
|
||||
type Table struct {
|
||||
textSearch *textSearchCTE
|
||||
|
||||
Name string
|
||||
DerivedTable Query
|
||||
PK []string
|
||||
FieldCount uint16
|
||||
debug bool
|
||||
}
|
||||
|
||||
// text search Common Table Expression
|
||||
type textSearchCTE struct {
|
||||
name string
|
||||
value string
|
||||
alias string
|
||||
}
|
||||
|
||||
// Debug when set true will print generated query string in stdout
|
||||
func (t *Table) Debug() Clause {
|
||||
t.debug = true
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Table) Field(f string) Field {
|
||||
return Field(t.Name + "." + f)
|
||||
}
|
||||
|
||||
// Insert table statement
|
||||
func (t *Table) Insert() InsertClause {
|
||||
qb := &insertQry{
|
||||
debug: t.debug,
|
||||
table: t.Name,
|
||||
fields: make([]string, 0, t.FieldCount),
|
||||
vals: make([]string, 0, t.FieldCount),
|
||||
args: make([]any, 0, t.FieldCount),
|
||||
}
|
||||
return qb
|
||||
}
|
||||
|
||||
func (t *Table) WithTextSearch(name, alias, textToSearch string) *Table {
|
||||
t.textSearch = &textSearchCTE{name: name, value: textToSearch, alias: alias}
|
||||
return t
|
||||
}
|
||||
|
||||
// Select table statement
|
||||
func (t *Table) Select(field ...Field) SelectClause {
|
||||
qb := &selectQry{
|
||||
debug: t.debug,
|
||||
fields: field,
|
||||
textSearch: t.textSearch,
|
||||
}
|
||||
|
||||
if t.DerivedTable != nil {
|
||||
tName, args := t.DerivedTable.Build(true)
|
||||
qb.table = "(" + tName + ") AS " + t.Name
|
||||
qb.args = args
|
||||
} else {
|
||||
qb.table = t.Name
|
||||
}
|
||||
|
||||
return qb
|
||||
}
|
||||
|
||||
// Update table statement
|
||||
func (t *Table) Update() UpdateClause {
|
||||
qb := &updateQry{
|
||||
debug: t.debug,
|
||||
table: t.Name,
|
||||
cols: make([]string, 0, t.FieldCount),
|
||||
args: make([]any, 0, t.FieldCount),
|
||||
}
|
||||
return qb
|
||||
}
|
||||
|
||||
// Delete table statement
|
||||
func (t *Table) Delete() DeleteCluase {
|
||||
qb := &deleteQry{
|
||||
debug: t.debug,
|
||||
table: t.Name,
|
||||
}
|
||||
|
||||
return qb
|
||||
}
|
||||
670
pgm_test.go
Normal file
670
pgm_test.go
Normal file
@@ -0,0 +1,670 @@
|
||||
// Patial Tech.
|
||||
// Author, Ankit Patial
|
||||
|
||||
package pgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
// TestGetPoolNotInitialized verifies that GetPool panics when pool is not initialized
|
||||
func TestGetPoolNotInitialized(t *testing.T) {
|
||||
// Save current pool state
|
||||
currentPool := poolPGX.Load()
|
||||
|
||||
// Clear the pool to simulate uninitialized state
|
||||
poolPGX.Store(nil)
|
||||
|
||||
// Restore pool after test
|
||||
defer func() {
|
||||
poolPGX.Store(currentPool)
|
||||
}()
|
||||
|
||||
// Verify panic occurs
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("GetPool should panic when pool not initialized")
|
||||
} else {
|
||||
// Check panic message
|
||||
msg, ok := r.(string)
|
||||
if !ok || !strings.Contains(msg, "not initialized") {
|
||||
t.Errorf("Expected panic message about initialization, got: %v", r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
GetPool() // Should panic
|
||||
}
|
||||
|
||||
// TestConfigValidation tests that InitPool validates configuration
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
wantPanic bool
|
||||
panicMsg string
|
||||
}{
|
||||
{
|
||||
name: "empty connection string",
|
||||
config: Config{
|
||||
ConnString: "",
|
||||
},
|
||||
wantPanic: true,
|
||||
panicMsg: "connection string is empty",
|
||||
},
|
||||
{
|
||||
name: "MinConns greater than MaxConns",
|
||||
config: Config{
|
||||
ConnString: "postgres://localhost/test",
|
||||
MaxConns: 5,
|
||||
MinConns: 10,
|
||||
},
|
||||
wantPanic: true,
|
||||
panicMsg: "MinConns",
|
||||
},
|
||||
{
|
||||
name: "negative MaxConns",
|
||||
config: Config{
|
||||
ConnString: "postgres://localhost/test",
|
||||
MaxConns: -1,
|
||||
},
|
||||
wantPanic: true,
|
||||
panicMsg: "negative",
|
||||
},
|
||||
{
|
||||
name: "negative MinConns",
|
||||
config: Config{
|
||||
ConnString: "postgres://localhost/test",
|
||||
MinConns: -1,
|
||||
},
|
||||
wantPanic: true,
|
||||
panicMsg: "negative",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
if tt.wantPanic && r == nil {
|
||||
t.Errorf("InitPool should panic for %s", tt.name)
|
||||
}
|
||||
if !tt.wantPanic && r != nil {
|
||||
t.Errorf("InitPool should not panic for %s, got: %v", tt.name, r)
|
||||
}
|
||||
if tt.wantPanic && r != nil {
|
||||
msg := ""
|
||||
switch v := r.(type) {
|
||||
case string:
|
||||
msg = v
|
||||
case error:
|
||||
msg = v.Error()
|
||||
}
|
||||
if !strings.Contains(msg, tt.panicMsg) {
|
||||
t.Errorf("Expected panic message to contain %q, got: %q", tt.panicMsg, msg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
InitPool(tt.config)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClosePoolIdempotent verifies ClosePool can be called multiple times safely
|
||||
func TestClosePoolIdempotent(t *testing.T) {
|
||||
// Save current pool
|
||||
currentPool := poolPGX.Load()
|
||||
defer func() {
|
||||
poolPGX.Store(currentPool)
|
||||
}()
|
||||
|
||||
// Set to nil
|
||||
poolPGX.Store(nil)
|
||||
|
||||
// Should not panic when called on nil pool
|
||||
ClosePool()
|
||||
ClosePool()
|
||||
ClosePool()
|
||||
|
||||
// Should work fine - no panic expected
|
||||
}
|
||||
|
||||
// TestIsNotFound tests the error checking utility
|
||||
func TestIsNotFound(t *testing.T) {
|
||||
// Test with pgx.ErrNoRows
|
||||
if !IsNotFound(pgx.ErrNoRows) {
|
||||
t.Error("IsNotFound should return true for pgx.ErrNoRows")
|
||||
}
|
||||
|
||||
// Test with other errors
|
||||
otherErr := errors.New("some other error")
|
||||
if IsNotFound(otherErr) {
|
||||
t.Error("IsNotFound should return false for non-ErrNoRows errors")
|
||||
}
|
||||
|
||||
// Test with nil
|
||||
if IsNotFound(nil) {
|
||||
t.Error("IsNotFound should return false for nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPgTime tests PostgreSQL timestamp conversion
|
||||
func TestPgTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
pgTime := PgTime(now)
|
||||
|
||||
if !pgTime.Valid {
|
||||
t.Error("PgTime should return valid timestamp")
|
||||
}
|
||||
|
||||
if !pgTime.Time.Equal(now) {
|
||||
t.Errorf("PgTime time mismatch: got %v, want %v", pgTime.Time, now)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPgTimeNow tests current time conversion
|
||||
func TestPgTimeNow(t *testing.T) {
|
||||
before := time.Now()
|
||||
pgTime := PgTimeNow()
|
||||
after := time.Now()
|
||||
|
||||
if !pgTime.Valid {
|
||||
t.Error("PgTimeNow should return valid timestamp")
|
||||
}
|
||||
|
||||
// Check that time is between before and after
|
||||
if pgTime.Time.Before(before) || pgTime.Time.After(after) {
|
||||
t.Error("PgTimeNow should return current time")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTsQueryFunctions tests full-text search query builders
|
||||
func TestTsQueryFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(string) string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "TsAndQuery",
|
||||
fn: TsAndQuery,
|
||||
input: "hello world",
|
||||
expected: "hello & world",
|
||||
},
|
||||
{
|
||||
name: "TsPrefixAndQuery",
|
||||
fn: TsPrefixAndQuery,
|
||||
input: "hello world",
|
||||
expected: "hello:* & world:*",
|
||||
},
|
||||
{
|
||||
name: "TsOrQuery",
|
||||
fn: TsOrQuery,
|
||||
input: "hello world",
|
||||
expected: "hello | world",
|
||||
},
|
||||
{
|
||||
name: "TsPrefixOrQuery",
|
||||
fn: TsPrefixOrQuery,
|
||||
input: "hello world",
|
||||
expected: "hello:* | world:*",
|
||||
},
|
||||
{
|
||||
name: "TsAndQuery with multiple words",
|
||||
fn: TsAndQuery,
|
||||
input: "one two three",
|
||||
expected: "one & two & three",
|
||||
},
|
||||
{
|
||||
name: "TsOrQuery with multiple words",
|
||||
fn: TsOrQuery,
|
||||
input: "one two three",
|
||||
expected: "one | two | three",
|
||||
},
|
||||
{
|
||||
name: "TsAndQuery with extra spaces",
|
||||
fn: TsAndQuery,
|
||||
input: "hello world",
|
||||
expected: "hello & world",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.fn(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("%s(%q) = %q, want %q", tt.name, tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFieldsWithSuffix tests the internal suffix function
|
||||
func TestFieldsWithSuffix(t *testing.T) {
|
||||
result := fieldsWithSufix("hello world", ":*")
|
||||
expected := []string{"hello:*", "world:*"}
|
||||
|
||||
if len(result) != len(expected) {
|
||||
t.Fatalf("fieldsWithSufix length mismatch: got %d, want %d", len(result), len(expected))
|
||||
}
|
||||
|
||||
for i, v := range result {
|
||||
if v != expected[i] {
|
||||
t.Errorf("fieldsWithSufix[%d] = %q, want %q", i, v, expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBeginTxErrorWrapping tests that BeginTx preserves error context
|
||||
func TestBeginTxErrorWrapping(t *testing.T) {
|
||||
// This test verifies basic behavior without requiring a database connection
|
||||
// Actual error wrapping would require a failing database connection
|
||||
|
||||
// Skip if no pool initialized (unit test environment)
|
||||
if poolPGX.Load() == nil {
|
||||
t.Skip("Skipping BeginTx test - requires initialized pool")
|
||||
}
|
||||
|
||||
// Just verify the function accepts a context
|
||||
// In a real scenario with a cancelled context, it would return an error
|
||||
ctx := context.Background()
|
||||
|
||||
// We can't fully test without a real DB connection, so just document behavior
|
||||
_, err := BeginTx(ctx)
|
||||
if err != nil {
|
||||
// Error is expected in test environment without valid DB
|
||||
t.Logf("BeginTx returned error (expected in test env): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestContextCancellation tests that cancelled context is handled
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
// Create an already-cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
// Note: This test would require an actual database connection to verify
|
||||
// that the context cancellation is properly propagated. We're testing
|
||||
// that the context is accepted by the function signature.
|
||||
|
||||
// The actual behavior would be tested in integration tests
|
||||
// For now, we just verify the function doesn't panic with cancelled context
|
||||
|
||||
// Skip if no pool initialized (unit test environment)
|
||||
if poolPGX.Load() == nil {
|
||||
t.Skip("Skipping context cancellation test - requires initialized pool")
|
||||
}
|
||||
|
||||
// This would return an error about cancelled context in real scenario
|
||||
_, err := BeginTx(ctx)
|
||||
if err == nil {
|
||||
t.Log("BeginTx with cancelled context - error expected in real scenario")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorTypes verifies exported error variables
|
||||
func TestErrorTypes(t *testing.T) {
|
||||
if ErrConnStringMissing == nil {
|
||||
t.Error("ErrConnStringMissing should be defined")
|
||||
}
|
||||
|
||||
if ErrInitTX == nil {
|
||||
t.Error("ErrInitTX should be defined")
|
||||
}
|
||||
|
||||
if ErrCommitTX == nil {
|
||||
t.Error("ErrCommitTX should be defined")
|
||||
}
|
||||
|
||||
if ErrNoRows == nil {
|
||||
t.Error("ErrNoRows should be defined")
|
||||
}
|
||||
|
||||
// Check error messages are descriptive
|
||||
if !strings.Contains(ErrConnStringMissing.Error(), "connection string") {
|
||||
t.Error("ErrConnStringMissing should mention connection string")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStringPoolGetPut tests the string builder pool
|
||||
func TestStringPoolGetPut(t *testing.T) {
|
||||
// Get a string builder from pool
|
||||
sb1 := getSB()
|
||||
if sb1 == nil {
|
||||
t.Fatal("getSB should return non-nil string builder")
|
||||
}
|
||||
|
||||
// Use it
|
||||
sb1.WriteString("test")
|
||||
if sb1.String() != "test" {
|
||||
t.Error("String builder should work normally")
|
||||
}
|
||||
|
||||
// Put it back
|
||||
putSB(sb1)
|
||||
|
||||
// Get another one - should be reset
|
||||
sb2 := getSB()
|
||||
if sb2 == nil {
|
||||
t.Fatal("getSB should return non-nil string builder")
|
||||
}
|
||||
|
||||
// Should be empty after reset
|
||||
if sb2.Len() != 0 {
|
||||
t.Error("String builder from pool should be reset")
|
||||
}
|
||||
|
||||
putSB(sb2)
|
||||
}
|
||||
|
||||
// TestConditionActionTypes tests condition action enum
|
||||
func TestConditionActionTypes(t *testing.T) {
|
||||
// Verify enum values are distinct
|
||||
actions := []CondAction{
|
||||
CondActionNothing,
|
||||
CondActionNeedToClose,
|
||||
CondActionSubQuery,
|
||||
}
|
||||
|
||||
seen := make(map[CondAction]bool)
|
||||
for _, action := range actions {
|
||||
if seen[action] {
|
||||
t.Errorf("Duplicate CondAction value: %v", action)
|
||||
}
|
||||
seen[action] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectQueryBuilderBasics tests basic query building
|
||||
func TestSelectQueryBuilderBasics(t *testing.T) {
|
||||
// Create a test table
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
emailField := Field("users.email")
|
||||
|
||||
// Test basic select
|
||||
qry := testTable.Select(idField, emailField)
|
||||
sql := qry.String()
|
||||
|
||||
if !strings.Contains(sql, "SELECT") {
|
||||
t.Error("Query should contain SELECT")
|
||||
}
|
||||
if !strings.Contains(sql, "users.id") {
|
||||
t.Error("Query should contain users.id field")
|
||||
}
|
||||
if !strings.Contains(sql, "users.email") {
|
||||
t.Error("Query should contain users.email field")
|
||||
}
|
||||
if !strings.Contains(sql, "FROM users") {
|
||||
t.Error("Query should contain FROM users")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhereConditionAccumulation tests that WHERE conditions accumulate
|
||||
func TestWhereConditionAccumulation(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
statusField := Field("users.status")
|
||||
|
||||
// Create a query and add multiple WHERE conditions
|
||||
qry := testTable.Select(idField).
|
||||
Where(idField.Eq(1)).
|
||||
Where(statusField.Eq("active"))
|
||||
|
||||
sql := qry.String()
|
||||
|
||||
// Both conditions should be present
|
||||
if !strings.Contains(sql, "WHERE") {
|
||||
t.Error("Query should contain WHERE clause")
|
||||
}
|
||||
|
||||
// Should have multiple conditions (this demonstrates accumulation)
|
||||
if strings.Count(sql, "$") < 2 {
|
||||
t.Error("Query should have multiple parameterized values")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInsertQueryBuilder tests insert query building
|
||||
func TestInsertQueryBuilder(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
|
||||
qry := testTable.Insert()
|
||||
|
||||
sql := qry.String()
|
||||
|
||||
if !strings.Contains(sql, "INSERT INTO users") {
|
||||
t.Error("Query should contain INSERT INTO users")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateQueryBuilder tests update query building
|
||||
func TestUpdateQueryBuilder(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
|
||||
qry := testTable.Update().
|
||||
Where(idField.Eq(1))
|
||||
|
||||
sql := qry.String()
|
||||
|
||||
if !strings.Contains(sql, "UPDATE users") {
|
||||
t.Error("Query should contain UPDATE users")
|
||||
}
|
||||
if !strings.Contains(sql, "WHERE") {
|
||||
t.Error("Query should contain WHERE clause")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteQueryBuilder tests delete query building
|
||||
func TestDeleteQueryBuilder(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
|
||||
qry := testTable.Delete().Where(idField.Eq(1))
|
||||
|
||||
sql := qry.String()
|
||||
|
||||
if !strings.Contains(sql, "DELETE FROM users") {
|
||||
t.Error("Query should contain DELETE FROM users")
|
||||
}
|
||||
if !strings.Contains(sql, "WHERE") {
|
||||
t.Error("Query should contain WHERE clause")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFieldConditioners tests various field condition methods
|
||||
func TestFieldConditioners(t *testing.T) {
|
||||
field := Field("users.age")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
condition Conditioner
|
||||
checkFunc func(string) bool
|
||||
}{
|
||||
{
|
||||
name: "Eq",
|
||||
condition: field.Eq(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " = $") },
|
||||
},
|
||||
{
|
||||
name: "NotEq",
|
||||
condition: field.NotEq(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " != $") },
|
||||
},
|
||||
{
|
||||
name: "Gt",
|
||||
condition: field.Gt(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " > $") },
|
||||
},
|
||||
{
|
||||
name: "Lt",
|
||||
condition: field.Lt(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " < $") },
|
||||
},
|
||||
{
|
||||
name: "Gte",
|
||||
condition: field.Gte(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " >= $") },
|
||||
},
|
||||
{
|
||||
name: "Lte",
|
||||
condition: field.Lte(25),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " <= $") },
|
||||
},
|
||||
{
|
||||
name: "IsNull",
|
||||
condition: field.IsNull(),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " IS NULL") },
|
||||
},
|
||||
{
|
||||
name: "IsNotNull",
|
||||
condition: field.IsNotNull(),
|
||||
checkFunc: func(s string) bool { return strings.Contains(s, " IS NOT NULL") },
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Build a simple query to see the condition
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
qry := testTable.Select(field).Where(tt.condition)
|
||||
sql := qry.String()
|
||||
|
||||
if !tt.checkFunc(sql) {
|
||||
t.Errorf("%s condition not properly rendered in SQL: %s", tt.name, sql)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFieldFunctions tests SQL function wrappers
|
||||
func TestFieldFunctions(t *testing.T) {
|
||||
field := Field("users.count")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
result Field
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "Count",
|
||||
result: field.Count(),
|
||||
contains: "COUNT(",
|
||||
},
|
||||
{
|
||||
name: "Sum",
|
||||
result: field.Sum(),
|
||||
contains: "SUM(",
|
||||
},
|
||||
{
|
||||
name: "Avg",
|
||||
result: field.Avg(),
|
||||
contains: "AVG(",
|
||||
},
|
||||
{
|
||||
name: "Max",
|
||||
result: field.Max(),
|
||||
contains: "MAX(",
|
||||
},
|
||||
{
|
||||
name: "Min",
|
||||
result: field.Min(),
|
||||
contains: "Min(",
|
||||
},
|
||||
{
|
||||
name: "Lower",
|
||||
result: field.Lower(),
|
||||
contains: "LOWER(",
|
||||
},
|
||||
{
|
||||
name: "Upper",
|
||||
result: field.Upper(),
|
||||
contains: "UPPER(",
|
||||
},
|
||||
{
|
||||
name: "Trim",
|
||||
result: field.Trim(),
|
||||
contains: "TRIM(",
|
||||
},
|
||||
{
|
||||
name: "Asc",
|
||||
result: field.Asc(),
|
||||
contains: " ASC",
|
||||
},
|
||||
{
|
||||
name: "Desc",
|
||||
result: field.Desc(),
|
||||
contains: " DESC",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resultStr := string(tt.result)
|
||||
if !strings.Contains(resultStr, tt.contains) {
|
||||
t.Errorf("%s should contain %q, got: %s", tt.name, tt.contains, resultStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInsertQueryValidation tests that Insert queries validate fields are set
|
||||
func TestInsertQueryValidation(t *testing.T) {
|
||||
testTable := &Table{Name: "users", FieldCount: 10}
|
||||
idField := Field("users.id")
|
||||
|
||||
t.Run("Exec without fields", func(t *testing.T) {
|
||||
qry := testTable.Insert()
|
||||
err := qry.Exec(context.Background())
|
||||
if err == nil {
|
||||
t.Error("Insert.Exec() should return error when no fields set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||
t.Errorf("Expected error about no fields, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExecTx without fields", func(t *testing.T) {
|
||||
qry := testTable.Insert()
|
||||
// We can't test ExecTx without a real transaction, but we can test the error is defined
|
||||
err := qry.ExecTx(context.Background(), nil)
|
||||
if err == nil {
|
||||
t.Error("Insert.ExecTx() should return error when no fields set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||
t.Errorf("Expected error about no fields, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("First without fields", func(t *testing.T) {
|
||||
qry := testTable.Insert().Returning(idField)
|
||||
var id int
|
||||
err := qry.First(context.Background(), &id)
|
||||
if err == nil {
|
||||
t.Error("Insert.First() should return error when no fields set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||
t.Errorf("Expected error about no fields, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FirstTx without fields", func(t *testing.T) {
|
||||
qry := testTable.Insert().Returning(idField)
|
||||
var id int
|
||||
err := qry.FirstTx(context.Background(), nil, &id)
|
||||
if err == nil {
|
||||
t.Error("Insert.FirstTx() should return error when no fields set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||
t.Errorf("Expected error about no fields, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
// Code generated by db-gen. DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package branchuser
|
||||
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
const (
|
||||
// All fields in table branch_users
|
||||
All pgm.Field = "branch_users.*"
|
||||
// BranchID field has db type "bigint NOT NULL"
|
||||
BranchID pgm.Field = "branch_users.branch_id"
|
||||
// UserID field has db type "bigint NOT NULL"
|
||||
@@ -1,10 +1,12 @@
|
||||
// Code generated by db-gen. DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package comment
|
||||
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
const (
|
||||
// All fields in table comments
|
||||
All pgm.Field = "comments.*"
|
||||
// ID field has db type "integer NOT NULL"
|
||||
ID pgm.Field = "comments.id"
|
||||
// PostID field has db type "integer NOT NULL"
|
||||
@@ -1,10 +1,12 @@
|
||||
// Code generated by db-gen. DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package employee
|
||||
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
const (
|
||||
// All fields in table employees
|
||||
All pgm.Field = "employees.*"
|
||||
// ID field has db type "integer NOT NULL"
|
||||
ID pgm.Field = "employees.id"
|
||||
// Name field has db type "var NOT NULL"
|
||||
@@ -1,10 +1,12 @@
|
||||
// Code generated by db-gen. DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package post
|
||||
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
const (
|
||||
// All fields in table posts
|
||||
All pgm.Field = "posts.*"
|
||||
// ID field has db type "integer NOT NULL"
|
||||
ID pgm.Field = "posts.id"
|
||||
// UserID field has db type "integer NOT NULL"
|
||||
@@ -1,15 +1,22 @@
|
||||
// Code generated by code.patial.tech/go/pgm/cmd
|
||||
// DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package db
|
||||
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
var (
|
||||
User = pgm.Table{Name: "users", FieldCount: 11}
|
||||
User = pgm.Table{Name: "users", FieldCount: 12}
|
||||
UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8}
|
||||
BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5}
|
||||
Post = pgm.Table{Name: "posts", FieldCount: 5}
|
||||
Comment = pgm.Table{Name: "comments", FieldCount: 5}
|
||||
Employee = pgm.Table{Name: "employees", FieldCount: 5}
|
||||
)
|
||||
|
||||
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||
t := pgm.Table{
|
||||
Name: tblName,
|
||||
DerivedTable: fromQry,
|
||||
}
|
||||
return t
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
// Code generated by db-gen. DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package user
|
||||
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
const (
|
||||
// All fields in table users
|
||||
All pgm.Field = "users.*"
|
||||
// ID field has db type "integer NOT NULL"
|
||||
ID pgm.Field = "users.id"
|
||||
// Name field has db type "character varying(255) NOT NULL"
|
||||
@@ -23,6 +25,8 @@ const (
|
||||
StatusID pgm.Field = "users.status_id"
|
||||
// MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying"
|
||||
MfaKind pgm.Field = "users.mfa_kind"
|
||||
// SearchVector field has db type "tsvector"
|
||||
SearchVector pgm.Field = "users.search_vector"
|
||||
// CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||
CreatedAt pgm.Field = "users.created_at"
|
||||
// UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||
@@ -1,10 +1,12 @@
|
||||
// Code generated by db-gen. DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package usersession
|
||||
|
||||
import "code.patial.tech/go/pgm"
|
||||
|
||||
const (
|
||||
// All fields in table user_sessions
|
||||
All pgm.Field = "user_sessions.*"
|
||||
// ID field has db type "character varying NOT NULL"
|
||||
ID pgm.Field = "user_sessions.id"
|
||||
// CreatedAt field has db type "timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL"
|
||||
@@ -1,3 +1,3 @@
|
||||
//go:generate go run code.patial.tech/go/pgm/cmd -o ./db ./schema.sql
|
||||
|
||||
package example
|
||||
package playground
|
||||
@@ -1,16 +1,16 @@
|
||||
package example
|
||||
package playground
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.patial.tech/go/pgm/example/db"
|
||||
"code.patial.tech/go/pgm/example/db/user"
|
||||
"code.patial.tech/go/pgm/playground/db"
|
||||
"code.patial.tech/go/pgm/playground/db/user"
|
||||
)
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
expected := "DELETE FROM users WHERE users.id = $1 AND users.status_id NOT IN($2)"
|
||||
expected := "DELETE FROM users WHERE users.id = $1 AND users.status_id != ANY($2)"
|
||||
got := db.User.Delete().
|
||||
Where(user.ID.Eq(1), user.StatusID.NotIn(1, 2, 3)).
|
||||
Where(user.ID.Eq(1), user.StatusID.NotAny(1, 2, 3)).
|
||||
String()
|
||||
if got != expected {
|
||||
t.Errorf("got %q, want %q", got, expected)
|
||||
@@ -1,11 +1,11 @@
|
||||
package example
|
||||
package playground
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.patial.tech/go/pgm"
|
||||
"code.patial.tech/go/pgm/example/db"
|
||||
"code.patial.tech/go/pgm/example/db/user"
|
||||
"code.patial.tech/go/pgm/playground/db"
|
||||
"code.patial.tech/go/pgm/playground/db/user"
|
||||
)
|
||||
|
||||
func TestInsertQuery(t *testing.T) {
|
||||
@@ -17,24 +17,7 @@ func TestInsertQuery(t *testing.T) {
|
||||
Returning(user.ID).
|
||||
String()
|
||||
|
||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id"
|
||||
if got != expected {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertSetMap(t *testing.T) {
|
||||
got := db.User.Insert().
|
||||
SetMap(map[pgm.Field]any{
|
||||
user.Email: "aa@aa.com",
|
||||
user.Phone: 8889991234,
|
||||
user.FirstName: "fname",
|
||||
user.LastName: "lname",
|
||||
}).
|
||||
Returning(user.ID).
|
||||
String()
|
||||
|
||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id"
|
||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING users.id"
|
||||
if got != expected {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
@@ -53,7 +36,7 @@ func TestInsertQuery2(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkInsertQuery-12 1952412 605.3 ns/op 1144 B/op 18 allocs/op
|
||||
// BenchmarkInsertQuery-12 2517757 459.6 ns/op 1032 B/op 13 allocs/op
|
||||
func BenchmarkInsertQuery(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_ = db.User.Insert().
|
||||
@@ -67,7 +50,7 @@ func BenchmarkInsertQuery(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkInsertSetMap-12 1534039 777.1 ns/op 1480 B/op 20 allocs/op
|
||||
// BenchmarkInsertSetMap-12 1812555 644.1 ns/op 1368 B/op 15 allocs/op
|
||||
func BenchmarkInsertSetMap(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_ = db.User.Insert().
|
||||
159
playground/qry_select_test.go
Normal file
159
playground/qry_select_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package playground
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.patial.tech/go/pgm"
|
||||
"code.patial.tech/go/pgm/playground/db"
|
||||
"code.patial.tech/go/pgm/playground/db/branchuser"
|
||||
"code.patial.tech/go/pgm/playground/db/employee"
|
||||
"code.patial.tech/go/pgm/playground/db/user"
|
||||
"code.patial.tech/go/pgm/playground/db/usersession"
|
||||
)
|
||||
|
||||
func TestQryBuilder2(t *testing.T) {
|
||||
got := db.User.Debug().Select(user.Email, user.FirstName).
|
||||
Join(db.UserSession, user.ID, usersession.UserID).
|
||||
Join(db.BranchUser, user.ID, branchuser.UserID).
|
||||
Where(
|
||||
user.ID.Eq(1),
|
||||
pgm.Or(
|
||||
user.StatusID.Eq(2),
|
||||
user.UpdatedAt.Eq(3),
|
||||
),
|
||||
user.MfaKind.Eq(4),
|
||||
pgm.Or(
|
||||
user.FirstName.Eq(5),
|
||||
user.MiddleName.Eq(6),
|
||||
),
|
||||
).
|
||||
Where(
|
||||
user.LastName.NotEq(7),
|
||||
user.Phone.Like("%123%"),
|
||||
user.UpdatedAt.IsNotNull(),
|
||||
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
||||
).
|
||||
Limit(10).
|
||||
Offset(100).
|
||||
String()
|
||||
|
||||
expected := "SELECT users.email, users.first_name FROM users JOIN user_sessions ON users.id = user_sessions.user_id" +
|
||||
" JOIN branch_users ON users.id = branch_users.user_id WHERE users.id = $1 AND (users.status_id = $2 OR users.updated_at = $3)" +
|
||||
" AND users.mfa_kind = $4 AND (users.first_name = $5 OR users.middle_name = $6) AND users.last_name != $7 AND users.phone" +
|
||||
" LIKE $8 AND users.updated_at IS NOT NULL AND users.email != ANY(SELECT users.id FROM users WHERE users.id = $9) LIMIT 10 OFFSET 100"
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithHaving(t *testing.T) {
|
||||
expected := "SELECT employees.department, AVG(employees.salary), COUNT(employees.id)" +
|
||||
" FROM employees GROUP BY employees.department HAVING AVG(employees.salary) > $1 AND COUNT(employees.id) > $2"
|
||||
got := db.Employee.
|
||||
Select(employee.Department, employee.Salary.Avg(), employee.ID.Count()).
|
||||
GroupBy(employee.Department).
|
||||
Having(employee.Salary.Avg().Gt(50000), employee.ID.Count().Gt(5)).
|
||||
String()
|
||||
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithJoin(t *testing.T) {
|
||||
got := db.User.Select(user.Email, user.FirstName).
|
||||
Join(db.UserSession, user.ID, usersession.UserID).
|
||||
LeftJoin(db.BranchUser, user.ID, branchuser.UserID, pgm.Or(branchuser.RoleID.Eq("1"), branchuser.RoleID.Eq("2"))).
|
||||
Where(
|
||||
user.ID.Eq(3),
|
||||
pgm.Or(
|
||||
user.StatusID.Eq(4),
|
||||
user.UpdatedAt.Eq(5),
|
||||
),
|
||||
).
|
||||
Limit(10).
|
||||
Offset(100).
|
||||
String()
|
||||
|
||||
expected := "SELECT users.email, users.first_name " +
|
||||
"FROM users JOIN user_sessions ON users.id = user_sessions.user_id " +
|
||||
"LEFT JOIN branch_users ON users.id = branch_users.user_id AND (branch_users.role_id = $1 OR branch_users.role_id = $2) " +
|
||||
"WHERE users.id = $3 AND (users.status_id = $4 OR users.updated_at = $5) " +
|
||||
"LIMIT 10 OFFSET 100"
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectDerived(t *testing.T) {
|
||||
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
|
||||
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
|
||||
" ORDER BY t.status_id, t.created_at DESC"
|
||||
|
||||
qry := db.User.
|
||||
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
|
||||
Where(user.StatusID.Eq(1))
|
||||
|
||||
tbl := db.DerivedTable("t", qry)
|
||||
got := tbl.
|
||||
Select(tbl.Field("*")).
|
||||
Where(tbl.Field("rn").Lte(5)).
|
||||
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
|
||||
String()
|
||||
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectTV(t *testing.T) {
|
||||
expected := "WITH ts AS (SELECT to_tsquery('english', $1) AS query)" +
|
||||
" SELECT users.first_name, users.last_name, users.email, TS_RANK(users.search_vector, ts.query) AS rank" +
|
||||
" FROM users" +
|
||||
" JOIN user_sessions ON users.id = user_sessions.user_id" +
|
||||
" CROSS JOIN ts" +
|
||||
" WHERE users.status_id = $2 AND users.search_vector @@ ts.query" +
|
||||
" ORDER BY rank DESC"
|
||||
|
||||
qry := db.User.
|
||||
WithTextSearch("ts", "query", "text to search").
|
||||
Select(user.FirstName, user.LastName, user.Email, user.SearchVector.TsRank("ts.query", "rank")).
|
||||
Join(db.UserSession, user.ID, usersession.UserID).
|
||||
Where(user.StatusID.Eq(1), user.SearchVector.TsQuery("ts.query")).
|
||||
OrderBy(pgm.Field("rank").Desc())
|
||||
|
||||
got := qry.String()
|
||||
|
||||
if expected != got {
|
||||
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
|
||||
func BenchmarkSelect(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_ = db.User.Select(user.Email, user.FirstName).
|
||||
Join(db.UserSession, user.ID, usersession.UserID).
|
||||
Join(db.BranchUser, user.ID, branchuser.UserID).
|
||||
Where(
|
||||
user.ID.Eq(1),
|
||||
pgm.Or(
|
||||
user.StatusID.Eq(2),
|
||||
user.UpdatedAt.Eq(3),
|
||||
),
|
||||
user.MfaKind.Eq(4),
|
||||
pgm.Or(
|
||||
user.FirstName.Eq(5),
|
||||
user.MiddleName.Eq(6),
|
||||
),
|
||||
).
|
||||
Where(
|
||||
user.LastName.NotEq(7),
|
||||
user.Phone.Like("%123%"),
|
||||
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
||||
).
|
||||
Limit(10).
|
||||
Offset(100).
|
||||
String()
|
||||
}
|
||||
}
|
||||
61
playground/qry_update_test.go
Normal file
61
playground/qry_update_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package playground
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"code.patial.tech/go/pgm/playground/db"
|
||||
"code.patial.tech/go/pgm/playground/db/user"
|
||||
)
|
||||
|
||||
func TestUpdateQuery(t *testing.T) {
|
||||
got := db.User.Update().
|
||||
Set(user.FirstName, "ankit").
|
||||
Set(user.MiddleName, "singh").
|
||||
Set(user.LastName, "patial").
|
||||
Where(
|
||||
user.Email.Eq("aa@aa.com"),
|
||||
).
|
||||
Where(
|
||||
user.StatusID.NotEq(1),
|
||||
).
|
||||
String()
|
||||
|
||||
expected := "UPDATE users SET first_name=$1, middle_name=$2, last_name=$3 WHERE users.email = $4 AND users.status_id != $5"
|
||||
if got != expected {
|
||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateQueryValidation(t *testing.T) {
|
||||
// Test that UPDATE without Set() returns error
|
||||
err := db.User.Update().
|
||||
Where(user.Email.Eq("aa@aa.com")).
|
||||
Exec(context.Background())
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when calling Exec() without Set(), got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no columns to update") {
|
||||
t.Errorf("Expected error message to contain 'no columns to update', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkUpdateQuery-12 2334889 503.6 ns/op 1112 B/op 17 allocs/op
|
||||
func BenchmarkUpdateQuery(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_ = db.User.Update().
|
||||
Set(user.FirstName, "ankit").
|
||||
Set(user.MiddleName, "singh").
|
||||
Set(user.LastName, "patial").
|
||||
Where(
|
||||
user.Email.Eq("aa@aa.com"),
|
||||
).
|
||||
Where(
|
||||
user.StatusID.NotEq(1),
|
||||
).
|
||||
String()
|
||||
}
|
||||
}
|
||||
@@ -61,6 +61,7 @@ CREATE TABLE public.users (
|
||||
last_name character varying(50) NOT NULL,
|
||||
status_id smallint,
|
||||
mfa_kind character varying(50) DEFAULT 'None'::character varying,
|
||||
search_vector tsvector,
|
||||
created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
98
pool.go
98
pool.go
@@ -1,98 +0,0 @@
|
||||
// Patial Tech.
|
||||
// Author, Ankit Patial
|
||||
|
||||
package pgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
poolPGX atomic.Pointer[pgxpool.Pool]
|
||||
poolStringBuilder = sync.Pool{
|
||||
New: func() any {
|
||||
return new(strings.Builder)
|
||||
},
|
||||
}
|
||||
|
||||
ErrInitTX = errors.New("failed to init db.tx")
|
||||
ErrCommitTX = errors.New("failed to commit db.tx")
|
||||
ErrNoRows = errors.New("no data found")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MaxConns int32
|
||||
MinConns int32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
}
|
||||
|
||||
func Init(connString string, conf *Config) {
|
||||
cfg, err := pgxpool.ParseConfig(connString)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if conf != nil {
|
||||
if conf.MaxConns > 0 {
|
||||
cfg.MaxConns = conf.MaxConns // 100
|
||||
}
|
||||
|
||||
if conf.MinConns > 0 {
|
||||
cfg.MinConns = conf.MaxConns // 5
|
||||
}
|
||||
|
||||
if conf.MaxConnLifetime > 0 {
|
||||
cfg.MaxConnLifetime = conf.MaxConnLifetime // time.Minute * 10
|
||||
}
|
||||
|
||||
if conf.MaxConnIdleTime > 0 {
|
||||
cfg.MaxConnIdleTime = conf.MaxConnIdleTime // time.Minute * 5
|
||||
}
|
||||
}
|
||||
|
||||
p, err := pgxpool.NewWithConfig(context.Background(), cfg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err = p.Ping(context.Background()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
poolPGX.Store(p)
|
||||
}
|
||||
|
||||
func GetPool() *pgxpool.Pool {
|
||||
return poolPGX.Load()
|
||||
}
|
||||
|
||||
// get string builder from pool
|
||||
func getSB() *strings.Builder {
|
||||
return poolStringBuilder.Get().(*strings.Builder)
|
||||
}
|
||||
|
||||
// put string builder back to pool
|
||||
func putSB(sb *strings.Builder) {
|
||||
sb.Reset()
|
||||
poolStringBuilder.Put(sb)
|
||||
}
|
||||
|
||||
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
||||
tx, err := poolPGX.Load().Begin(ctx)
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
return nil, errors.New("failed to open db tx")
|
||||
}
|
||||
|
||||
return tx, err
|
||||
}
|
||||
200
qry.go
200
qry.go
@@ -7,156 +7,17 @@ import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type (
|
||||
Clause interface {
|
||||
Insert() InsertClause
|
||||
Select(fields ...Field) SelectClause
|
||||
// Insert() InsertSet
|
||||
// Update() UpdateSet
|
||||
// Delete() WhereOrExec
|
||||
}
|
||||
|
||||
SelectClause interface {
|
||||
// Join and Inner Join are same
|
||||
Join(m Table, t1Field, t2Field Field) SelectClause
|
||||
LeftJoin(m Table, t1Field, t2Field Field) SelectClause
|
||||
RightJoin(m Table, t1Field, t2Field Field) SelectClause
|
||||
FullJoin(m Table, t1Field, t2Field Field) SelectClause
|
||||
CrossJoin(m Table) SelectClause
|
||||
WhereClause
|
||||
OrderByClause
|
||||
GroupByClause
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
raw(prefixArgs []any) (string, []any)
|
||||
}
|
||||
|
||||
WhereClause interface {
|
||||
Where(cond ...Conditioner) AfterWhere
|
||||
}
|
||||
|
||||
AfterWhere interface {
|
||||
WhereClause
|
||||
GroupByClause
|
||||
OrderByClause
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
GroupByClause interface {
|
||||
GroupBy(fields ...Field) AfterGroupBy
|
||||
}
|
||||
|
||||
AfterGroupBy interface {
|
||||
HavinClause
|
||||
OrderByClause
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
HavinClause interface {
|
||||
Having(cond ...Conditioner) AfterHaving
|
||||
}
|
||||
|
||||
AfterHaving interface {
|
||||
OrderByClause
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
OrderByClause interface {
|
||||
OrderBy(fields ...Field) AfterOrderBy
|
||||
}
|
||||
|
||||
AfterOrderBy interface {
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
LimitClause interface {
|
||||
Limit(v int) AfterLimit
|
||||
}
|
||||
|
||||
AfterLimit interface {
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
OffsetClause interface {
|
||||
Offset(v int) AfterOffset
|
||||
}
|
||||
|
||||
AfterOffset interface {
|
||||
LimitClause
|
||||
Query
|
||||
}
|
||||
|
||||
Conditioner interface {
|
||||
Condition(args *[]any, idx int) string
|
||||
}
|
||||
|
||||
Insert interface {
|
||||
Set(field Field, val any) InsertClause
|
||||
SetMap(fields map[Field]any) InsertClause
|
||||
}
|
||||
|
||||
InsertClause interface {
|
||||
Insert
|
||||
Returning(field Field) First
|
||||
OnConflict(fields ...Field) Do
|
||||
Execute
|
||||
Stringer
|
||||
}
|
||||
|
||||
Do interface {
|
||||
DoNothing() Execute
|
||||
DoUpdate(fields ...Field) Execute
|
||||
}
|
||||
|
||||
Update interface {
|
||||
Set(field Field, val any) UpdateClause
|
||||
SetMap(fields map[Field]any) UpdateClause
|
||||
}
|
||||
|
||||
UpdateClause interface {
|
||||
Update
|
||||
Where(cond ...Conditioner) WhereOrExec
|
||||
}
|
||||
|
||||
WhereOrExec interface {
|
||||
Where(cond ...Conditioner) WhereOrExec
|
||||
Execute
|
||||
}
|
||||
|
||||
Query interface {
|
||||
First
|
||||
All
|
||||
Stringer
|
||||
}
|
||||
|
||||
First interface {
|
||||
First(ctx context.Context, dest ...any) error
|
||||
FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error
|
||||
Stringer
|
||||
}
|
||||
|
||||
All interface {
|
||||
// Query rows
|
||||
//
|
||||
// don't forget to close() rows
|
||||
All(ctx context.Context, rows RowsCb) error
|
||||
// Query rows
|
||||
//
|
||||
// don't forget to close() rows
|
||||
AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error
|
||||
Update() UpdateClause
|
||||
Delete() DeleteCluase
|
||||
}
|
||||
|
||||
Execute interface {
|
||||
@@ -169,26 +30,26 @@ type (
|
||||
String() string
|
||||
}
|
||||
|
||||
RowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
Conditioner interface {
|
||||
Condition(args *[]any, idx int) string
|
||||
}
|
||||
|
||||
RowsCb func(row RowScanner) error
|
||||
)
|
||||
|
||||
func joinFileds(fields []Field) string {
|
||||
sb := getSB()
|
||||
defer putSB(sb)
|
||||
for i, f := range fields {
|
||||
if i == 0 {
|
||||
sb.WriteString(f.String())
|
||||
} else {
|
||||
sb.WriteString(", ")
|
||||
sb.WriteString(f.String())
|
||||
}
|
||||
}
|
||||
var sbPool = sync.Pool{
|
||||
New: func() any {
|
||||
return new(strings.Builder)
|
||||
},
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
// get string builder from pool
|
||||
func getSB() *strings.Builder {
|
||||
return sbPool.Get().(*strings.Builder)
|
||||
}
|
||||
|
||||
// put string builder back to pool
|
||||
func putSB(sb *strings.Builder) {
|
||||
sb.Reset()
|
||||
sbPool.Put(sb)
|
||||
}
|
||||
|
||||
func And(cond ...Conditioner) Conditioner {
|
||||
@@ -208,12 +69,16 @@ func (cv *Cond) Condition(args *[]any, argIdx int) string {
|
||||
}
|
||||
|
||||
// 2. normal condition
|
||||
*args = append(*args, cv.Val)
|
||||
var op string
|
||||
if strings.HasSuffix(cv.op, "$") {
|
||||
op = cv.op + strconv.Itoa(argIdx+1)
|
||||
if cv.Val != nil {
|
||||
*args = append(*args, cv.Val)
|
||||
if strings.HasSuffix(cv.op, "$") {
|
||||
op = cv.op + strconv.Itoa(argIdx+1)
|
||||
} else {
|
||||
op = strings.Replace(cv.op, "$", "$"+strconv.Itoa(argIdx+1), 1)
|
||||
}
|
||||
} else {
|
||||
op = strings.Replace(cv.op, "$", "$"+strconv.Itoa(argIdx+1), 1)
|
||||
op = cv.op
|
||||
}
|
||||
|
||||
if cv.action == CondActionNeedToClose {
|
||||
@@ -227,13 +92,14 @@ func (c *CondGroup) Condition(args *[]any, argIdx int) string {
|
||||
defer putSB(sb)
|
||||
|
||||
sb.WriteString("(")
|
||||
currentIdx := argIdx
|
||||
for i, cond := range c.cond {
|
||||
if i == 0 {
|
||||
sb.WriteString(cond.Condition(args, argIdx+i))
|
||||
} else {
|
||||
if i > 0 {
|
||||
sb.WriteString(c.op)
|
||||
sb.WriteString(cond.Condition(args, argIdx+i))
|
||||
}
|
||||
argsLenBefore := len(*args)
|
||||
sb.WriteString(cond.Condition(args, currentIdx))
|
||||
currentIdx += len(*args) - argsLenBefore
|
||||
}
|
||||
sb.WriteString(")")
|
||||
return sb.String()
|
||||
|
||||
@@ -7,6 +7,10 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
DeleteCluase interface {
|
||||
WhereOrExec
|
||||
}
|
||||
|
||||
deleteQry struct {
|
||||
table string
|
||||
condition []Conditioner
|
||||
@@ -15,14 +19,11 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func (t *Table) Delete() WhereOrExec {
|
||||
qb := &deleteQry{
|
||||
table: t.Name,
|
||||
debug: t.debug,
|
||||
}
|
||||
|
||||
return qb
|
||||
}
|
||||
// WARNING: DELETE without WHERE clause will delete ALL rows in the table.
|
||||
// Always use Where() to specify conditions unless you intentionally want to delete all rows.
|
||||
// Example:
|
||||
// table.Delete().Where(field.Eq(value)).Exec(ctx) // Safe
|
||||
// table.Delete().Exec(ctx) // Deletes ALL rows!
|
||||
|
||||
func (q *deleteQry) Where(cond ...Conditioner) WhereOrExec {
|
||||
q.condition = append(q.condition, cond...)
|
||||
|
||||
@@ -5,36 +5,40 @@ package pgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type insertQry struct {
|
||||
returing *string
|
||||
onConflict *string
|
||||
|
||||
table string
|
||||
conflictAction string
|
||||
|
||||
fields []string
|
||||
vals []string
|
||||
args []any
|
||||
debug bool
|
||||
}
|
||||
|
||||
func (t *Table) Insert() Insert {
|
||||
qb := &insertQry{
|
||||
table: t.Name,
|
||||
fields: make([]string, 0, t.FieldCount),
|
||||
vals: make([]string, 0, t.FieldCount),
|
||||
args: make([]any, 0, t.FieldCount),
|
||||
debug: t.debug,
|
||||
type (
|
||||
InsertClause interface {
|
||||
Insert
|
||||
Returning(field Field) First
|
||||
OnConflict(fields ...Field) Do
|
||||
Execute
|
||||
Stringer
|
||||
}
|
||||
return qb
|
||||
}
|
||||
|
||||
Insert interface {
|
||||
Set(field Field, val any) InsertClause
|
||||
SetMap(fields map[Field]any) InsertClause
|
||||
}
|
||||
|
||||
insertQry struct {
|
||||
returing *string
|
||||
onConflict *string
|
||||
|
||||
table string
|
||||
conflictAction string
|
||||
|
||||
fields []string
|
||||
vals []string
|
||||
args []any
|
||||
debug bool
|
||||
}
|
||||
)
|
||||
|
||||
func (q *insertQry) Set(field Field, val any) InsertClause {
|
||||
q.fields = append(q.fields, field.Name())
|
||||
@@ -51,7 +55,7 @@ func (q *insertQry) SetMap(cols map[Field]any) InsertClause {
|
||||
}
|
||||
|
||||
func (q *insertQry) Returning(field Field) First {
|
||||
col := field.Name()
|
||||
col := field.String()
|
||||
q.returing = &col
|
||||
return q
|
||||
}
|
||||
@@ -80,21 +84,28 @@ func (q *insertQry) DoNothing() Execute {
|
||||
}
|
||||
|
||||
func (q *insertQry) DoUpdate(fields ...Field) Execute {
|
||||
var sb strings.Builder
|
||||
sb := getSB()
|
||||
defer putSB(sb)
|
||||
|
||||
sb.WriteString("DO UPDATE SET ")
|
||||
for i, f := range fields {
|
||||
col := f.Name()
|
||||
if i == 0 {
|
||||
fmt.Fprintf(&sb, "%s = EXCLUDED.%s", col, col)
|
||||
} else {
|
||||
fmt.Fprintf(&sb, ", %s = EXCLUDED.%s", col, col)
|
||||
if i > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(col)
|
||||
sb.WriteString(" = EXCLUDED.")
|
||||
sb.WriteString(col)
|
||||
}
|
||||
|
||||
q.conflictAction = "DO UPDATE SET " + sb.String()
|
||||
q.conflictAction = sb.String()
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *insertQry) Exec(ctx context.Context) error {
|
||||
if len(q.fields) == 0 {
|
||||
return errors.New("insert query has no fields to insert: call Set() before Exec()")
|
||||
}
|
||||
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -103,6 +114,9 @@ func (q *insertQry) Exec(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||
if len(q.fields) == 0 {
|
||||
return errors.New("insert query has no fields to insert: call Set() before ExecTx()")
|
||||
}
|
||||
_, err := tx.Exec(ctx, q.String(), q.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -112,10 +126,16 @@ func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||
}
|
||||
|
||||
func (q *insertQry) First(ctx context.Context, dest ...any) error {
|
||||
if len(q.fields) == 0 {
|
||||
return errors.New("insert query has no fields to insert: call Set() before First()")
|
||||
}
|
||||
return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||
}
|
||||
|
||||
func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
||||
if len(q.fields) == 0 {
|
||||
return errors.New("insert query has no fields to insert: call Set() before FirstTx()")
|
||||
}
|
||||
return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||
}
|
||||
|
||||
|
||||
245
qry_select.go
245
qry_select.go
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -14,7 +15,128 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
SelectClause interface {
|
||||
// Join and Inner Join are same
|
||||
Join(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||
LeftJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||
RightJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||
FullJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||
CrossJoin(m Table) SelectClause
|
||||
WhereClause
|
||||
OrderByClause
|
||||
GroupByClause
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
raw(prefixArgs []any) (string, []any)
|
||||
}
|
||||
|
||||
WhereClause interface {
|
||||
Where(cond ...Conditioner) AfterWhere
|
||||
}
|
||||
|
||||
AfterWhere interface {
|
||||
WhereClause
|
||||
GroupByClause
|
||||
OrderByClause
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
GroupByClause interface {
|
||||
GroupBy(fields ...Field) AfterGroupBy
|
||||
}
|
||||
|
||||
AfterGroupBy interface {
|
||||
HavinClause
|
||||
OrderByClause
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
HavinClause interface {
|
||||
Having(cond ...Conditioner) AfterHaving
|
||||
}
|
||||
|
||||
AfterHaving interface {
|
||||
OrderByClause
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
OrderByClause interface {
|
||||
OrderBy(fields ...Field) AfterOrderBy
|
||||
}
|
||||
|
||||
AfterOrderBy interface {
|
||||
LimitClause
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
LimitClause interface {
|
||||
Limit(v int) AfterLimit
|
||||
}
|
||||
|
||||
AfterLimit interface {
|
||||
OffsetClause
|
||||
Query
|
||||
}
|
||||
|
||||
OffsetClause interface {
|
||||
Offset(v int) AfterOffset
|
||||
}
|
||||
|
||||
AfterOffset interface {
|
||||
LimitClause
|
||||
Query
|
||||
}
|
||||
|
||||
Do interface {
|
||||
DoNothing() Execute
|
||||
DoUpdate(fields ...Field) Execute
|
||||
}
|
||||
|
||||
Query interface {
|
||||
First
|
||||
All
|
||||
Stringer
|
||||
Bulder
|
||||
}
|
||||
|
||||
RowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
RowsCb func(row RowScanner) error
|
||||
|
||||
First interface {
|
||||
First(ctx context.Context, dest ...any) error
|
||||
FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error
|
||||
Stringer
|
||||
}
|
||||
|
||||
All interface {
|
||||
// Query rows
|
||||
//
|
||||
// don't forget to close() rows
|
||||
All(ctx context.Context, rows RowsCb) error
|
||||
// Query rows
|
||||
//
|
||||
// don't forget to close() rows
|
||||
AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error
|
||||
}
|
||||
|
||||
Bulder interface {
|
||||
Build(needArgs bool) (qry string, args []any)
|
||||
}
|
||||
|
||||
selectQry struct {
|
||||
textSearch *textSearchCTE
|
||||
|
||||
table string
|
||||
fields []Field
|
||||
args []any
|
||||
@@ -23,9 +145,11 @@ type (
|
||||
groupBy []Field
|
||||
having []Conditioner
|
||||
orderBy []Field
|
||||
limit int
|
||||
offset int
|
||||
debug bool
|
||||
|
||||
limit int
|
||||
offset int
|
||||
|
||||
debug bool
|
||||
}
|
||||
|
||||
CondAction uint8
|
||||
@@ -51,34 +175,47 @@ const (
|
||||
CondActionSubQuery
|
||||
)
|
||||
|
||||
// Select clause
|
||||
func (t Table) Select(field ...Field) SelectClause {
|
||||
qb := &selectQry{
|
||||
table: t.Name,
|
||||
debug: t.debug,
|
||||
fields: field,
|
||||
func (q *selectQry) Join(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||
return q.buildJoin(t, "JOIN", t1Field, t2Field, cond...)
|
||||
}
|
||||
|
||||
func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||
return q.buildJoin(t, "LEFT JOIN", t1Field, t2Field, cond...)
|
||||
}
|
||||
|
||||
func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||
return q.buildJoin(t, "RIGHT JOIN", t1Field, t2Field, cond...)
|
||||
}
|
||||
|
||||
func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||
return q.buildJoin(t, "FULL JOIN", t1Field, t2Field, cond...)
|
||||
}
|
||||
|
||||
func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||
str := joinKW + " " + t.Name + " ON " + t1Field.String() + " = " + t2Field.String()
|
||||
if len(cond) == 0 { // Join with no condition
|
||||
q.join = append(q.join, str)
|
||||
return q
|
||||
}
|
||||
|
||||
return qb
|
||||
}
|
||||
// Join has condition(s)
|
||||
sb := getSB()
|
||||
defer putSB(sb)
|
||||
sb.Grow(len(str) * 2)
|
||||
|
||||
func (q *selectQry) Join(t Table, t1Field, t2Field Field) SelectClause {
|
||||
q.join = append(q.join, "JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
||||
return q
|
||||
}
|
||||
sb.WriteString(str)
|
||||
sb.WriteString(" AND ")
|
||||
|
||||
func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field) SelectClause {
|
||||
q.join = append(q.join, "LEFT JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
||||
return q
|
||||
}
|
||||
var argIdx int
|
||||
for i, c := range cond {
|
||||
argIdx = len(q.args)
|
||||
if i > 0 {
|
||||
sb.WriteString(" AND ")
|
||||
}
|
||||
sb.WriteString(c.Condition(&q.args, argIdx))
|
||||
}
|
||||
|
||||
func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field) SelectClause {
|
||||
q.join = append(q.join, "RIGHT JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field) SelectClause {
|
||||
q.join = append(q.join, "FULL JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
||||
q.join = append(q.join, sb.String())
|
||||
return q
|
||||
}
|
||||
|
||||
@@ -127,9 +264,13 @@ func (q *selectQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
||||
|
||||
func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
||||
rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrNoRows
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrNoRows
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
@@ -138,13 +279,21 @@ func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for errors from iteration
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
||||
rows, err := tx.Query(ctx, q.String(), q.args...)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrNoRows
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrNoRows
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -154,6 +303,11 @@ func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for errors from iteration
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -163,11 +317,28 @@ func (q *selectQry) raw(prefixArgs []any) (string, []any) {
|
||||
}
|
||||
|
||||
func (q *selectQry) String() string {
|
||||
qry, _ := q.Build(false)
|
||||
if q.debug {
|
||||
fmt.Println("***")
|
||||
fmt.Println(qry)
|
||||
fmt.Printf("%+v\n", q.args)
|
||||
fmt.Println("***")
|
||||
}
|
||||
return qry
|
||||
}
|
||||
|
||||
func (q *selectQry) Build(needArgs bool) (qry string, args []any) {
|
||||
sb := getSB()
|
||||
defer putSB(sb)
|
||||
|
||||
sb.Grow(q.averageLen())
|
||||
|
||||
if q.textSearch != nil {
|
||||
var ts = q.textSearch
|
||||
q.args = slices.Insert(q.args, 0, any(ts.value))
|
||||
sb.WriteString("WITH " + ts.name + " AS (SELECT to_tsquery('english', $1) AS " + ts.alias + ") ")
|
||||
}
|
||||
|
||||
// SELECT
|
||||
sb.WriteString("SELECT ")
|
||||
sb.WriteString(joinFileds(q.fields))
|
||||
@@ -178,6 +349,11 @@ func (q *selectQry) String() string {
|
||||
sb.WriteString(" " + strings.Join(q.join, " "))
|
||||
}
|
||||
|
||||
// Search Query Cross join
|
||||
if q.textSearch != nil {
|
||||
sb.WriteString(" CROSS JOIN " + q.textSearch.name)
|
||||
}
|
||||
|
||||
// WHERE
|
||||
if len(q.where) > 0 {
|
||||
sb.WriteString(" WHERE ")
|
||||
@@ -228,14 +404,19 @@ func (q *selectQry) String() string {
|
||||
sb.WriteString(strconv.Itoa(q.offset))
|
||||
}
|
||||
|
||||
qry := sb.String()
|
||||
qry = sb.String()
|
||||
if q.debug {
|
||||
fmt.Println("***")
|
||||
fmt.Println(qry)
|
||||
fmt.Printf("%+v\n", q.args)
|
||||
fmt.Println("***")
|
||||
}
|
||||
return qry
|
||||
|
||||
if needArgs {
|
||||
args = slices.Clone(q.args)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (q *selectQry) averageLen() int {
|
||||
|
||||
@@ -5,29 +5,37 @@ package pgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type updateQry struct {
|
||||
table string
|
||||
cols []string
|
||||
condition []Conditioner
|
||||
args []any
|
||||
debug bool
|
||||
}
|
||||
|
||||
func (t *Table) Update() Update {
|
||||
qb := &updateQry{
|
||||
table: t.Name,
|
||||
debug: t.debug,
|
||||
cols: make([]string, 0, t.FieldCount),
|
||||
args: make([]any, 0, t.FieldCount),
|
||||
type (
|
||||
Update interface {
|
||||
Set(field Field, val any) UpdateClause
|
||||
SetMap(fields map[Field]any) UpdateClause
|
||||
}
|
||||
return qb
|
||||
}
|
||||
|
||||
UpdateClause interface {
|
||||
Update
|
||||
Where(cond ...Conditioner) WhereOrExec
|
||||
}
|
||||
|
||||
WhereOrExec interface {
|
||||
Where(cond ...Conditioner) WhereOrExec
|
||||
Execute
|
||||
}
|
||||
|
||||
updateQry struct {
|
||||
table string
|
||||
cols []string
|
||||
condition []Conditioner
|
||||
args []any
|
||||
debug bool
|
||||
}
|
||||
)
|
||||
|
||||
func (q *updateQry) Set(field Field, val any) UpdateClause {
|
||||
col := field.Name()
|
||||
@@ -49,6 +57,9 @@ func (q *updateQry) Where(cond ...Conditioner) WhereOrExec {
|
||||
}
|
||||
|
||||
func (q *updateQry) Exec(ctx context.Context) error {
|
||||
if len(q.cols) == 0 {
|
||||
return errors.New("update query has no columns to update: call Set() before Exec()")
|
||||
}
|
||||
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -58,6 +69,9 @@ func (q *updateQry) Exec(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (q *updateQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||
if len(q.cols) == 0 {
|
||||
return errors.New("update query has no columns to update: call Set() before ExecTx()")
|
||||
}
|
||||
_, err := tx.Exec(ctx, q.String(), q.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user