forked from go/pgm
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8d8c22d781 | |||
| c2cf7ff088 | |||
| cc7e6b7b3f | |||
| 48f1d1952e | |||
| 1d9d9d9308 | |||
| 29cddb6389 | |||
| 551e2123bc | |||
| a2b984c342 | |||
| a795c0e8d6 | |||
| bb6a45732f | |||
| 2551e07b3e | |||
| 9837fb1e37 | |||
| 12d6fface6 | |||
| 325103e8ef | |||
| 6f5748d3d3 | |||
| b25f9367ed | |||
| 8750f3ad95 | |||
| ad1faf2056 |
314
CLAUDE.md
Normal file
314
CLAUDE.md
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
|
||||||
|
**pgm** is a lightweight, type-safe PostgreSQL query builder for Go, built on top of jackc/pgx. It generates Go code from SQL schema files, enabling compile-time safety for database queries without the overhead of traditional ORMs.
|
||||||
|
|
||||||
|
**Core Philosophy:**
|
||||||
|
- Schema defined in SQL (not Go)
|
||||||
|
- Minimal code generation (only table/column definitions)
|
||||||
|
- Users provide their own models (no forced abstractions)
|
||||||
|
- Type-safe query building with fluent API
|
||||||
|
- Zero reflection for maximum performance
|
||||||
|
|
||||||
|
## Development Commands
|
||||||
|
|
||||||
|
### Building
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build CLI with version from git tags
|
||||||
|
make build
|
||||||
|
|
||||||
|
# Build with specific version
|
||||||
|
make build VERSION=v1.2.3
|
||||||
|
|
||||||
|
# Install to GOPATH/bin
|
||||||
|
make install
|
||||||
|
|
||||||
|
# Check current version
|
||||||
|
make version
|
||||||
|
pgm -version
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests in playground
|
||||||
|
make test
|
||||||
|
|
||||||
|
# Run benchmarks for SELECT queries
|
||||||
|
make bench-select
|
||||||
|
|
||||||
|
# Run code generator on playground schema
|
||||||
|
make run
|
||||||
|
```
|
||||||
|
|
||||||
|
### Code Generation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Generate Go code from SQL schema
|
||||||
|
pgm -o ./db ./schema.sql
|
||||||
|
|
||||||
|
# The generator expects a SQL schema file and outputs:
|
||||||
|
# - One package per table in the output directory
|
||||||
|
# - Each package contains table and column definitions
|
||||||
|
# Example output: db/users/users.go, db/posts/posts.go
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Core Components
|
||||||
|
|
||||||
|
#### 1. Query Builder System (qry_*.go files)
|
||||||
|
|
||||||
|
The query builder uses a **mutable, stateful design** for conditional query building:
|
||||||
|
|
||||||
|
- **qry_select.go**: SELECT queries with joins, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET
|
||||||
|
- **qry_insert.go**: INSERT queries with RETURNING and UPSERT (ON CONFLICT) support
|
||||||
|
- **qry_update.go**: UPDATE queries with conditional WHERE clauses
|
||||||
|
- **qry_delete.go**: DELETE queries with WHERE conditions
|
||||||
|
|
||||||
|
**Important:** Query builders accumulate state with each method call. They are designed for conditional building within a single query but should NOT be reused across multiple separate queries.
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ✅ CORRECT - Conditional building
|
||||||
|
query := users.User.Select(users.ID, users.Email)
|
||||||
|
if nameFilter != "" {
|
||||||
|
query = query.Where(users.Name.Like("%" + nameFilter + "%"))
|
||||||
|
}
|
||||||
|
err := query.First(ctx, &id, &email)
|
||||||
|
|
||||||
|
// ❌ WRONG - Reusing builder across separate queries
|
||||||
|
baseQuery := users.User.Select(users.ID)
|
||||||
|
baseQuery.Where(users.ID.Eq(1)).First(ctx, &id1) // Adds ID=1
|
||||||
|
baseQuery.Where(users.Status.Eq(2)).First(ctx, &id2) // Has BOTH conditions!
|
||||||
|
```
|
||||||
|
|
||||||
|
Query builders are NOT thread-safe. Each goroutine must create its own query instance.
|
||||||
|
|
||||||
|
#### 2. Field System (pgm_field.go)
|
||||||
|
|
||||||
|
Fields are type-safe column references with:
|
||||||
|
- Comparison operators: `Eq()`, `NotEq()`, `Gt()`, `Lt()`, `Gte()`, `Lte()`
|
||||||
|
- Pattern matching: `Like()`, `ILike()`, `LikeFold()`, `EqFold()`
|
||||||
|
- NULL checks: `IsNull()`, `IsNotNull()`
|
||||||
|
- Array operations: `Any()`, `NotAny()`
|
||||||
|
- Aggregate functions: `Count()`, `Sum()`, `Avg()`, `Min()`, `Max()`
|
||||||
|
- String functions: `Lower()`, `Upper()`, `Trim()`, `StringEscape()`
|
||||||
|
- Special functions: `ConcatWs()`, `StringAgg()`, `DateTrunc()`, `RowNumber()`
|
||||||
|
|
||||||
|
**Security:** Functions like `ConcatWs()`, `StringAgg()`, and `DateTrunc()` validate inputs and escape SQL strings. They use allowlists for parameters like date truncation levels.
|
||||||
|
|
||||||
|
#### 3. Connection Pool (pgm.go)
|
||||||
|
|
||||||
|
Global connection pool using pgxpool with atomic pointer for thread safety:
|
||||||
|
|
||||||
|
- `InitPool(Config)`: Initialize once at startup (panics if called multiple times or with invalid config)
|
||||||
|
- `GetPool()`: Retrieve pool instance (panics if not initialized)
|
||||||
|
- `ClosePool()`: Graceful shutdown
|
||||||
|
- `BeginTx(ctx)`: Start transactions
|
||||||
|
|
||||||
|
The pool is stored in an `atomic.Pointer[pgxpool.Pool]` for lock-free concurrent access.
|
||||||
|
|
||||||
|
#### 4. Code Generator (cmd/)
|
||||||
|
|
||||||
|
Parses SQL schema files and generates Go code:
|
||||||
|
|
||||||
|
- **cmd/main.go**: CLI entry point with version flag support
|
||||||
|
- **cmd/parse.go**: Regex-based SQL parser (has known limitations)
|
||||||
|
- **cmd/generate.go**: Code generation with go/format for proper formatting
|
||||||
|
- **cmd/version.go**: Version string generation from build-time ldflags
|
||||||
|
|
||||||
|
**Parser Limitations:**
|
||||||
|
- No multi-line comments (`/* */`)
|
||||||
|
- Limited support for complex data types (arrays, JSON, JSONB)
|
||||||
|
- No advanced PostgreSQL features (PARTITION BY, INHERITS)
|
||||||
|
- Some constraints (CHECK, EXCLUDE) not parsed
|
||||||
|
|
||||||
|
Generated files include a header comment with version and timestamp:
|
||||||
|
```go
|
||||||
|
// Code generated by code.patial.tech/go/pgm/cmd v1.2.3 on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 5. Table System (pgm_table.go)
|
||||||
|
|
||||||
|
Minimal table metadata:
|
||||||
|
```go
|
||||||
|
type Table struct {
|
||||||
|
Name string // Table name
|
||||||
|
FieldCount int // Number of columns
|
||||||
|
PK []string // Primary key columns
|
||||||
|
DerivedTable Query // For subqueries/CTEs
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Provides factory methods: `Select()`, `Insert()`, `Update()`, `Delete()`
|
||||||
|
|
||||||
|
### String Builder Pool
|
||||||
|
|
||||||
|
Uses `sync.Pool` for efficient string building (qry.go):
|
||||||
|
```go
|
||||||
|
var sbPool = sync.Pool{
|
||||||
|
New: func() any { return new(strings.Builder) }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
All query builders use `getSB()` / `putSB()` to reduce allocations.
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
|
||||||
|
Custom errors in pgm.go:
|
||||||
|
- `ErrConnStringMissing`: Connection string validation
|
||||||
|
- `ErrInitTX`: Transaction initialization failure
|
||||||
|
- `ErrCommitTX`: Transaction commit failure
|
||||||
|
- `ErrNoRows`: Wrapper for pgx.ErrNoRows
|
||||||
|
|
||||||
|
Use `pgm.IsNotFound(err)` to check for no rows errors.
|
||||||
|
|
||||||
|
### Generated Code Structure
|
||||||
|
|
||||||
|
For a schema with `users` and `posts` tables:
|
||||||
|
```
|
||||||
|
db/
|
||||||
|
├── schema.go # Table definitions and DerivedTable helper
|
||||||
|
├── user/
|
||||||
|
│ └── users.go # User table columns as constants
|
||||||
|
├── post/
|
||||||
|
│ └── posts.go # Post table columns as constants
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Each table file exports constants like:
|
||||||
|
```go
|
||||||
|
const (
|
||||||
|
All pgm.Field = "users.*"
|
||||||
|
ID pgm.Field = "users.id"
|
||||||
|
Email pgm.Field = "users.email"
|
||||||
|
// ... all columns
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Naming Conventions
|
||||||
|
|
||||||
|
- **Table pluralization**: `pluralToSingular()` in cmd/generate.go handles irregular plurals (people→person, children→child) and common patterns (-ies→-y, -ves→-fe, -es→-e, -s→"")
|
||||||
|
- **Field naming**: Snake_case columns converted to PascalCase (first_name → FirstName)
|
||||||
|
- **ID suffix**: Fields ending in `_id` become `ID` not `Id` (user_id → UserID)
|
||||||
|
|
||||||
|
## Key Implementation Details
|
||||||
|
|
||||||
|
### Query Building Strategy
|
||||||
|
|
||||||
|
1. **Pre-allocation**: Queries estimate final string length via `averageLen()` methods to reduce allocations
|
||||||
|
2. **Conditional accumulation**: All clauses stored in slices, built on demand
|
||||||
|
3. **Parameterized queries**: Uses PostgreSQL numbered parameters (`$1`, `$2`, etc.)
|
||||||
|
4. **Builder methods return interfaces**: Enforces correct method call sequences at compile time
|
||||||
|
|
||||||
|
Example type progression:
|
||||||
|
```go
|
||||||
|
SelectClause → WhereClause → AfterWhere → OrderByClause → Query → First/All
|
||||||
|
```
|
||||||
|
|
||||||
|
### Transaction Handling
|
||||||
|
|
||||||
|
Pattern used throughout:
|
||||||
|
```go
|
||||||
|
tx, err := pgm.BeginTx(ctx)
|
||||||
|
if err != nil { return err }
|
||||||
|
defer tx.Rollback(ctx) // Safe to call after commit
|
||||||
|
|
||||||
|
// ... operations using ExecTx, FirstTx, AllTx methods ...
|
||||||
|
|
||||||
|
return tx.Commit(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
All query execution methods have `Tx` variants that accept `pgx.Tx`.
|
||||||
|
|
||||||
|
### Full-Text Search Helpers
|
||||||
|
|
||||||
|
Helper functions for PostgreSQL tsvector queries:
|
||||||
|
- `TsAndQuery()`: AND operator between terms
|
||||||
|
- `TsPrefixAndQuery()`: AND with prefix matching `:*`
|
||||||
|
- `TsOrQuery()`: OR operator
|
||||||
|
- `TsPrefixOrQuery()`: OR with prefix matching
|
||||||
|
|
||||||
|
Used with tsvector columns via `Field.TsQuery()`.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Tests are in the `playground/` directory:
|
||||||
|
- **playground/schema.sql**: Test database schema
|
||||||
|
- **playground/db/**: Generated code from schema
|
||||||
|
- **playground/*_test.go**: Integration tests for SELECT, INSERT, UPDATE, DELETE
|
||||||
|
- **playground/local_select_test.go**: Additional SELECT test cases
|
||||||
|
|
||||||
|
Tests require a running PostgreSQL instance. The playground uses the generated code to verify the query builder works correctly.
|
||||||
|
|
||||||
|
## Version Management
|
||||||
|
|
||||||
|
Version is injected at build time via ldflags:
|
||||||
|
```bash
|
||||||
|
go build -ldflags "-X main.version=v1.2.3" ./cmd
|
||||||
|
```
|
||||||
|
|
||||||
|
The Makefile automatically extracts version from git tags:
|
||||||
|
```makefile
|
||||||
|
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||||
|
```
|
||||||
|
|
||||||
|
This version appears in:
|
||||||
|
- CLI `pgm -version` output
|
||||||
|
- Generated file headers
|
||||||
|
- Version string formatting in cmd/version.go
|
||||||
|
|
||||||
|
## Common Patterns
|
||||||
|
|
||||||
|
### Connection Pool Initialization
|
||||||
|
Always initialize once at startup and defer cleanup:
|
||||||
|
```go
|
||||||
|
func main() {
|
||||||
|
pgm.InitPool(pgm.Config{
|
||||||
|
ConnString: os.Getenv("DATABASE_URL"),
|
||||||
|
MaxConns: 25,
|
||||||
|
MinConns: 5,
|
||||||
|
})
|
||||||
|
defer pgm.ClosePool()
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Safe Query Execution
|
||||||
|
Check for no rows using the helper:
|
||||||
|
```go
|
||||||
|
err := users.User.Select(users.Email).Where(users.ID.Eq(id)).First(ctx, &email)
|
||||||
|
if pgm.IsNotFound(err) {
|
||||||
|
// Handle not found case
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// Handle other errors
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Validation Requirements
|
||||||
|
- UPDATE requires at least one `Set()` call
|
||||||
|
- INSERT requires at least one `Set()` call
|
||||||
|
- DELETE without `Where()` deletes ALL rows (dangerous!)
|
||||||
|
|
||||||
|
All execution methods validate these requirements and return descriptive errors.
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
- **sync.Pool**: Reuses string builders across queries
|
||||||
|
- **Pre-allocation**: Queries pre-calculate buffer sizes
|
||||||
|
- **Zero reflection**: Direct field access, no runtime type inspection
|
||||||
|
- **pgxpool**: Leverages jackc/pgx's efficient connection pooling
|
||||||
|
- **Direct scanning**: Users scan into their own types, no intermediate mapping
|
||||||
|
|
||||||
|
## Security Notes
|
||||||
|
|
||||||
|
- All queries use parameterized statements (PostgreSQL numbered parameters)
|
||||||
|
- Field validation in functions like `DateTrunc()` uses allowlists
|
||||||
|
- SQL string escaping in `escapeSQLString()` for literal values
|
||||||
|
- Identifier validation via regex in `validateSQLIdentifier()`
|
||||||
|
- Connection pool configuration validates against negative values and invalid ranges
|
||||||
25
Makefile
25
Makefile
@@ -1,4 +1,25 @@
|
|||||||
|
.PHONY: run bench-select test build install
|
||||||
|
|
||||||
|
# Version can be set via: make build VERSION=v1.2.3
|
||||||
|
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||||
|
|
||||||
run:
|
run:
|
||||||
go run ./cmd -o ./example/db ./example/schema.sql
|
go run ./cmd -o ./playground/db ./playground/schema.sql
|
||||||
|
|
||||||
bench-select:
|
bench-select:
|
||||||
go test ./example -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
go test ./playground -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test ./playground
|
||||||
|
|
||||||
|
# Build with version information
|
||||||
|
build:
|
||||||
|
go build -ldflags "-X main.version=$(VERSION)" -o pgm ./cmd
|
||||||
|
|
||||||
|
# Install to GOPATH/bin with version
|
||||||
|
install:
|
||||||
|
go install -ldflags "-X main.version=$(VERSION)" ./cmd
|
||||||
|
|
||||||
|
# Show current version
|
||||||
|
version:
|
||||||
|
@echo "Version: $(VERSION)"
|
||||||
|
|||||||
807
README.md
807
README.md
@@ -1,81 +1,794 @@
|
|||||||
# pgm - PostgreSQL Query Mapper
|
# pgm - PostgreSQL Query Mapper
|
||||||
|
|
||||||
A lightweight ORM built on top of [jackc/pgx](https://github.com/jackc/pgx) database connection pool.
|
[](https://pkg.go.dev/code.patial.tech/go/pgm)
|
||||||
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
## ORMs I Like in the Go Ecosystem
|
A lightweight, type-safe PostgreSQL query builder for Go, built on top of [jackc/pgx](https://github.com/jackc/pgx). **pgm** generates Go code from your SQL schema, enabling you to write SQL queries with compile-time safety and autocompletion support.
|
||||||
|
|
||||||
- [ent](https://github.com/ent/ent)
|
## Features
|
||||||
- [sqlc](https://github.com/sqlc-dev/sqlc)
|
|
||||||
|
|
||||||
## Why Not Use `ent`?
|
- **Type-safe queries** - Column and table names are validated at compile time
|
||||||
|
- **Zero reflection** - Fast performance with no runtime reflection overhead
|
||||||
|
- **SQL schema-based** - Generate Go code directly from your SQL schema files
|
||||||
|
- **Fluent API** - Intuitive query builder with method chaining
|
||||||
|
- **Transaction support** - First-class support for pgx transactions
|
||||||
|
- **Full-text search** - Built-in PostgreSQL full-text search helpers
|
||||||
|
- **Connection pooling** - Leverages pgx connection pool for optimal performance
|
||||||
|
- **Minimal code generation** - Only generates what you need, no bloat
|
||||||
|
|
||||||
`ent` is a feature-rich ORM with schema definition, automatic migrations, integration with `gqlgen` (GraphQL server), and more. It provides nearly everything you could want in an ORM.
|
## Table of Contents
|
||||||
|
|
||||||
However, it can be overkill. The generated code supports a wide range of features, many of which you may not use, significantly increasing the compiled binary size.
|
- [Why pgm?](#why-pgm)
|
||||||
|
- [Installation](#installation)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [Usage Examples](#usage-examples)
|
||||||
|
- [SELECT Queries](#select-queries)
|
||||||
|
- [INSERT Queries](#insert-queries)
|
||||||
|
- [UPDATE Queries](#update-queries)
|
||||||
|
- [DELETE Queries](#delete-queries)
|
||||||
|
- [Joins](#joins)
|
||||||
|
- [Transactions](#transactions)
|
||||||
|
- [Full-Text Search](#full-text-search)
|
||||||
|
- [CLI Tool](#cli-tool)
|
||||||
|
- [API Documentation](#api-documentation)
|
||||||
|
- [Contributing](#contributing)
|
||||||
|
- [License](#license)
|
||||||
|
|
||||||
## Why Not Use `sqlc`?
|
## Why pgm?
|
||||||
|
|
||||||
`sqlc` is a great tool, but it often feels like the database layer introduces its own models. This forces you to either map your application’s models to these database models or use the database models directly, which may not align with your application’s design.
|
### The Problem with Existing ORMs
|
||||||
|
|
||||||
## Issues with Existing ORMs
|
While Go has excellent ORMs like [ent](https://github.com/ent/ent) and [sqlc](https://github.com/sqlc-dev/sqlc), they come with tradeoffs:
|
||||||
|
|
||||||
Here are some common pain points with ORMs:
|
**ent** - Feature-rich but heavy:
|
||||||
|
- Generates extensive code for features you may never use
|
||||||
|
- Significantly increases binary size
|
||||||
|
- Complex schema definition in Go instead of SQL
|
||||||
|
- Auto-migrations can obscure actual database schema
|
||||||
|
|
||||||
- **Auto Migrations**: Many ORMs either lack robust migration support or implement complex methods for simple schema changes. This can obscure the database schema, making it harder to understand and maintain. A database schema should be defined in clear SQL statements that can be tested in a SQL query editor. Tools like [dbmate](https://github.com/amacneil/dbmate) provide a mature solution for managing migrations, usable via CLI or in code.
|
**sqlc** - Great tool, but:
|
||||||
|
- Creates separate database models, forcing model mapping
|
||||||
|
- Query results require their own generated types
|
||||||
|
- Less flexibility in dynamic query building
|
||||||
|
|
||||||
- **Excessive Code Generation**: ORMs often generate excessive code for various conditions and scenarios, much of which goes unused.
|
### The pgm Approach
|
||||||
|
|
||||||
- **Generated Models for Queries**: Auto-generated models for `SELECT` queries force you to either adopt them or map them to your application’s models, adding complexity.
|
**pgm** takes a hybrid approach:
|
||||||
|
|
||||||
## A Hybrid Approach: Plain SQL Queries with `pgm`
|
✅ **Schema as SQL** - Define your database schema in pure SQL, where it belongs
|
||||||
|
✅ **Minimal generation** - Only generates table and column definitions
|
||||||
|
✅ **Your models** - Use your own application models, no forced abstractions
|
||||||
|
✅ **Type safety** - Catch schema changes at compile time
|
||||||
|
✅ **SQL power** - Full control over your queries with a fluent API
|
||||||
|
✅ **Migration-friendly** - Use mature tools like [dbmate](https://github.com/amacneil/dbmate) for migrations
|
||||||
|
|
||||||
Plain SQL queries are not inherently bad but come with challenges:
|
## Installation
|
||||||
|
|
||||||
- **Schema Change Detection**: Changes in the database schema are not easily detected.
|
|
||||||
- **SQL Injection Risks**: Without parameterized queries, SQL injection becomes a concern.
|
|
||||||
|
|
||||||
`pgm` addresses these issues by providing a lightweight CLI tool that generates Go files for your database schema. These files help you write SQL queries while keeping track of schema changes, avoiding hardcoded table and column names.
|
|
||||||
|
|
||||||
## Generating `pgm` Schema Files
|
|
||||||
|
|
||||||
Run the following command to generate schema files:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
go run code.partial.tech/go/pgm/cmd -o ./db ./schema.sql
|
go get code.patial.tech/go/pgm
|
||||||
```
|
```
|
||||||
once you have the schama files created you can use `pgm` as
|
|
||||||
|
Install the CLI tool for schema code generation:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go install code.patial.tech/go/pgm/cmd@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
### Building from Source
|
||||||
|
|
||||||
|
Build with automatic version detection (uses git tags):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build with version from git tags
|
||||||
|
make build
|
||||||
|
|
||||||
|
# Build with specific version
|
||||||
|
make build VERSION=v1.2.3
|
||||||
|
|
||||||
|
# Install to GOPATH/bin
|
||||||
|
make install
|
||||||
|
|
||||||
|
# Or build manually with version
|
||||||
|
go build -ldflags "-X main.version=v1.2.3" -o pgm ./cmd
|
||||||
|
```
|
||||||
|
|
||||||
|
Check the version:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pgm -version
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Create Your Schema
|
||||||
|
|
||||||
|
Create a SQL schema file `schema.sql` or use the one created by [dbmate](https://github.com/amacneil/dbmate):
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE TABLE users (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
email VARCHAR(255) UNIQUE NOT NULL,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE posts (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id UUID NOT NULL REFERENCES users(id),
|
||||||
|
title VARCHAR(500) NOT NULL,
|
||||||
|
content TEXT,
|
||||||
|
published BOOLEAN DEFAULT false,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Generate Go Code
|
||||||
|
|
||||||
|
Run the pgm CLI tool:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pgm -o ./db ./schema.sql
|
||||||
|
```
|
||||||
|
|
||||||
|
This generates Go files for each table in `./db/`:
|
||||||
|
- `db/users/users.go` - Table and column definitions for users
|
||||||
|
- `db/posts/posts.go` - Table and column definitions for posts
|
||||||
|
|
||||||
|
### 3. Use in Your Code
|
||||||
|
|
||||||
```go
|
```go
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"code.partial.tech/go/pgm"
|
"context"
|
||||||
"myapp/db/user" // scham create by pgm/cmd
|
"log"
|
||||||
|
|
||||||
|
"code.patial.tech/go/pgm"
|
||||||
|
"yourapp/db/users"
|
||||||
|
"yourapp/db/posts"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MyModel struct {
|
|
||||||
ID string
|
|
||||||
Email string
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
println("Initializing pgx connection pool")
|
// Initialize connection pool
|
||||||
pgm.InitPool(pgm.Config{
|
pgm.InitPool(pgm.Config{
|
||||||
ConnString: url,
|
ConnString: "postgres://user:pass@localhost:5432/dbname",
|
||||||
|
MaxConns: 25,
|
||||||
|
MinConns: 5,
|
||||||
})
|
})
|
||||||
|
defer pgm.ClosePool() // Ensure graceful shutdown
|
||||||
// Select query to fetch the first record
|
|
||||||
// Assumes the schema is defined in the "db" package with a User table
|
ctx := context.Background()
|
||||||
var v MyModel
|
|
||||||
err := db.User.Select(user.ID, user.Email).
|
// Query a user
|
||||||
Where(user.Email.Like("anki%")).
|
var email string
|
||||||
First(context.TODO(), &v.ID, &v.Email)
|
err := users.User.Select(users.Email).
|
||||||
|
Where(users.ID.Eq("some-uuid")).
|
||||||
|
First(ctx, &email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
println("Error:", err.Error())
|
log.Fatal(err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
println("User email:", v.Email)
|
log.Printf("User email: %s", email)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Important: Query Builder Lifecycle
|
||||||
|
|
||||||
|
### ✅ Conditional Building (CORRECT)
|
||||||
|
|
||||||
|
Query builders are **mutable by design** to support conditional query building:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ✅ CORRECT - Conditional building pattern
|
||||||
|
query := users.User.Select(users.ID, users.Email, users.Name)
|
||||||
|
|
||||||
|
// Add conditions based on filters
|
||||||
|
if nameFilter != "" {
|
||||||
|
query = query.Where(users.Name.Like("%" + nameFilter + "%"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusFilter > 0 {
|
||||||
|
query = query.Where(users.Status.Eq(statusFilter))
|
||||||
|
}
|
||||||
|
|
||||||
|
if sortByName {
|
||||||
|
query = query.OrderBy(users.Name.Asc())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the final query with all accumulated conditions
|
||||||
|
err := query.First(ctx, &id, &email, &name)
|
||||||
|
```
|
||||||
|
|
||||||
|
**This is the intended use!** The builder accumulates your conditions, which is powerful and flexible.
|
||||||
|
|
||||||
|
### ❌ Unintentional Reuse (INCORRECT)
|
||||||
|
|
||||||
|
Don't try to create a "base query" and reuse it for **multiple different queries**:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ WRONG - Trying to reuse for multiple separate queries
|
||||||
|
baseQuery := users.User.Select(users.ID, users.Email)
|
||||||
|
|
||||||
|
// First query - adds ID condition
|
||||||
|
baseQuery.Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||||
|
|
||||||
|
// Second query - ALSO has ID=1 from above PLUS Status=2!
|
||||||
|
baseQuery.Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||||
|
// This executes: WHERE users.id = 1 AND users.status = 2 (WRONG!)
|
||||||
|
|
||||||
|
// ✅ CORRECT - Each separate query gets its own builder
|
||||||
|
users.User.Select(users.ID, users.Email).Where(users.ID.Eq(1)).First(ctx, &id1, &email1)
|
||||||
|
users.User.Select(users.ID, users.Email).Where(users.Status.Eq(2)).First(ctx, &id2, &email2)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why?** Query builders are mutable and accumulate state. Each method call modifies the builder, so reusing the same builder causes conditions to stack up.
|
||||||
|
|
||||||
|
### Thread Safety
|
||||||
|
|
||||||
|
⚠️ **Query builders are NOT thread-safe** and must not be shared across goroutines:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ✅ CORRECT - Each goroutine creates its own query
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
var email string
|
||||||
|
err := users.User.Select(users.Email).
|
||||||
|
Where(users.ID.Eq(id)).
|
||||||
|
First(ctx, &email)
|
||||||
|
// Process result...
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ❌ WRONG - Sharing query builder across goroutines
|
||||||
|
baseQuery := users.User.Select(users.Email)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
var email string
|
||||||
|
baseQuery.Where(users.ID.Eq(id)).First(ctx, &email)
|
||||||
|
// RACE CONDITION! Multiple goroutines modifying shared state
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Thread-Safe Components:**
|
||||||
|
- ✅ Connection Pool - Safe for concurrent use
|
||||||
|
- ✅ Table objects - Safe to share
|
||||||
|
- ❌ Query builders - Create new instance per goroutine
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### SELECT Queries
|
||||||
|
|
||||||
|
#### Basic Select
|
||||||
|
|
||||||
|
```go
|
||||||
|
var user struct {
|
||||||
|
ID string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
err := users.User.Select(users.ID, users.Email, users.Name).
|
||||||
|
Where(users.Email.Eq("john@example.com")).
|
||||||
|
First(ctx, &user.ID, &user.Email, &user.Name)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Select with Multiple Conditions
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Select(users.ID, users.Email).
|
||||||
|
Where(
|
||||||
|
users.Email.Like("john%"),
|
||||||
|
users.CreatedAt.Gt(time.Now().AddDate(0, -1, 0)),
|
||||||
|
).
|
||||||
|
OrderBy(users.CreatedAt.Desc()).
|
||||||
|
Limit(10).
|
||||||
|
First(ctx, &user.ID, &user.Email)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Select All with Callback
|
||||||
|
|
||||||
|
```go
|
||||||
|
var userList []User
|
||||||
|
|
||||||
|
err := users.User.Select(users.ID, users.Email, users.Name).
|
||||||
|
Where(users.Name.Like("J%")).
|
||||||
|
OrderBy(users.Name.Asc()).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
var u User
|
||||||
|
if err := row.Scan(&u.ID, &u.Email, &u.Name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userList = append(userList, u)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Pagination
|
||||||
|
|
||||||
|
```go
|
||||||
|
page := 2
|
||||||
|
pageSize := 20
|
||||||
|
|
||||||
|
err := users.User.Select(users.ID, users.Email).
|
||||||
|
OrderBy(users.CreatedAt.Desc()).
|
||||||
|
Limit(pageSize).
|
||||||
|
Offset((page - 1) * pageSize).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
// Process rows
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Grouping and Having
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := posts.Post.Select(posts.UserID, pgm.Count(posts.ID)).
|
||||||
|
GroupBy(posts.UserID).
|
||||||
|
Having(pgm.Count(posts.ID).Gt(5)).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
var userID string
|
||||||
|
var postCount int
|
||||||
|
return row.Scan(&userID, &postCount)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### INSERT Queries
|
||||||
|
|
||||||
|
#### Simple Insert
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe").
|
||||||
|
Set(users.CreatedAt, pgm.PgTimeNow()).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Insert with Map
|
||||||
|
|
||||||
|
```go
|
||||||
|
data := map[pgm.Field]any{
|
||||||
|
users.Email: "jane@example.com",
|
||||||
|
users.Name: "Jane Doe",
|
||||||
|
users.CreatedAt: pgm.PgTimeNow(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := users.User.Insert().
|
||||||
|
SetMap(data).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Insert with RETURNING
|
||||||
|
|
||||||
|
```go
|
||||||
|
var newID string
|
||||||
|
|
||||||
|
err := users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe").
|
||||||
|
Returning(users.ID).
|
||||||
|
First(ctx, &newID)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Upsert (INSERT ... ON CONFLICT)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Do nothing on conflict
|
||||||
|
err := users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe").
|
||||||
|
OnConflict(users.Email).
|
||||||
|
DoNothing().
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
// Update on conflict
|
||||||
|
err := users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe Updated").
|
||||||
|
OnConflict(users.Email).
|
||||||
|
DoUpdate(users.Name).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### UPDATE Queries
|
||||||
|
|
||||||
|
#### Simple Update
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Update().
|
||||||
|
Set(users.Name, "John Smith").
|
||||||
|
Where(users.ID.Eq("some-uuid")).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Update Multiple Fields
|
||||||
|
|
||||||
|
```go
|
||||||
|
updates := map[pgm.Field]any{
|
||||||
|
users.Name: "John Smith",
|
||||||
|
users.Email: "john.smith@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := users.User.Update().
|
||||||
|
SetMap(updates).
|
||||||
|
Where(users.ID.Eq("some-uuid")).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Conditional Update
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Update().
|
||||||
|
Set(users.Name, "Updated Name").
|
||||||
|
Where(
|
||||||
|
users.Email.Like("%@example.com"),
|
||||||
|
users.CreatedAt.Lt(time.Now().AddDate(-1, 0, 0)),
|
||||||
|
).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### DELETE Queries
|
||||||
|
|
||||||
|
#### Simple Delete
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Delete().
|
||||||
|
Where(users.ID.Eq("some-uuid")).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Conditional Delete
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := posts.Post.Delete().
|
||||||
|
Where(
|
||||||
|
posts.Published.Eq(false),
|
||||||
|
posts.CreatedAt.Lt(time.Now().AddDate(0, 0, -30)),
|
||||||
|
).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Joins
|
||||||
|
|
||||||
|
#### Inner Join
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := posts.Post.Select(posts.Title, users.Name).
|
||||||
|
Join(users.User, posts.UserID, users.ID).
|
||||||
|
Where(users.Email.Eq("john@example.com")).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
var title, userName string
|
||||||
|
return row.Scan(&title, &userName)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Left Join
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := users.User.Select(users.Name, posts.Title).
|
||||||
|
LeftJoin(posts.Post, users.ID, posts.UserID).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
var userName, postTitle string
|
||||||
|
return row.Scan(&userName, &postTitle)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Join with Additional Conditions
|
||||||
|
|
||||||
|
```go
|
||||||
|
err := posts.Post.Select(posts.Title, users.Name).
|
||||||
|
Join(users.User, posts.UserID, users.ID, users.Email.Like("%@example.com")).
|
||||||
|
Where(posts.Published.Eq(true)).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
// Process rows
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Transactions
|
||||||
|
|
||||||
|
#### Basic Transaction
|
||||||
|
|
||||||
|
```go
|
||||||
|
tx, err := pgm.BeginTx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
// Insert user
|
||||||
|
var userID string
|
||||||
|
err = users.User.Insert().
|
||||||
|
Set(users.Email, "jane@example.com").
|
||||||
|
Set(users.Name, "Jane Doe").
|
||||||
|
Returning(users.ID).
|
||||||
|
FirstTx(ctx, tx, &userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert post
|
||||||
|
err = posts.Post.Insert().
|
||||||
|
Set(posts.UserID, userID).
|
||||||
|
Set(posts.Title, "My First Post").
|
||||||
|
Set(posts.Content, "Hello, World!").
|
||||||
|
ExecTx(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit transaction
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Full-Text Search
|
||||||
|
|
||||||
|
PostgreSQL full-text search helpers:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Search with AND operator (all terms must match)
|
||||||
|
searchQuery := pgm.TsAndQuery("golang database")
|
||||||
|
// Result: "golang & database"
|
||||||
|
|
||||||
|
// Search with prefix matching
|
||||||
|
searchQuery := pgm.TsPrefixAndQuery("gol data")
|
||||||
|
// Result: "gol:* & data:*"
|
||||||
|
|
||||||
|
// Search with OR operator (any term matches)
|
||||||
|
searchQuery := pgm.TsOrQuery("golang rust")
|
||||||
|
// Result: "golang | rust"
|
||||||
|
|
||||||
|
// Prefix OR search
|
||||||
|
searchQuery := pgm.TsPrefixOrQuery("go ru")
|
||||||
|
// Result: "go:* | ru:*"
|
||||||
|
|
||||||
|
// Use in query (assuming you have a tsvector column)
|
||||||
|
err := posts.Post.Select(posts.Title, posts.Content).
|
||||||
|
Where(posts.SearchVector.Match(pgm.TsPrefixAndQuery(searchTerm))).
|
||||||
|
OrderBy(posts.CreatedAt.Desc()).
|
||||||
|
All(ctx, func(row pgm.RowScanner) error {
|
||||||
|
// Process results
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## CLI Tool
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pgm -o <output_directory> <schema.sql>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Options
|
||||||
|
|
||||||
|
```bash
|
||||||
|
-o string Output directory path (required)
|
||||||
|
-version Show version information
|
||||||
|
```
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Generate from a single schema file
|
||||||
|
pgm -o ./db ./schema.sql
|
||||||
|
|
||||||
|
# Generate from concatenated migrations
|
||||||
|
cat migrations/*.sql > /tmp/schema.sql && pgm -o ./db /tmp/schema.sql
|
||||||
|
|
||||||
|
# Check version
|
||||||
|
pgm -version
|
||||||
|
```
|
||||||
|
|
||||||
|
### Known Limitations
|
||||||
|
|
||||||
|
The CLI tool uses a regex-based SQL parser with the following limitations:
|
||||||
|
|
||||||
|
- ❌ Multi-line comments `/* */` are not supported
|
||||||
|
- ❌ Complex data types (arrays, JSON, JSONB) may not parse correctly
|
||||||
|
- ❌ Quoted identifiers with special characters may fail
|
||||||
|
- ❌ Advanced PostgreSQL features (PARTITION BY, INHERITS) not supported
|
||||||
|
- ❌ Some constraints (CHECK, EXCLUDE) are not parsed
|
||||||
|
|
||||||
|
**Workarounds:**
|
||||||
|
- Use simple CREATE TABLE statements
|
||||||
|
- Avoid complex PostgreSQL-specific syntax in schema files
|
||||||
|
- Split complex schemas into multiple simple statements
|
||||||
|
- Remove comments before running the generator
|
||||||
|
|
||||||
|
For complex schemas, consider contributing a more robust parser or using a proper SQL parser library.
|
||||||
|
|
||||||
|
### Generated Code Structure
|
||||||
|
|
||||||
|
For a table named `users`, pgm generates:
|
||||||
|
|
||||||
|
```
|
||||||
|
db/
|
||||||
|
└── users/
|
||||||
|
└── users.go
|
||||||
|
```
|
||||||
|
|
||||||
|
The generated file contains:
|
||||||
|
- Generated code header with version and timestamp
|
||||||
|
- Table definition (`User`)
|
||||||
|
- Column field definitions (`ID`, `Email`, `Name`, etc.)
|
||||||
|
- Type-safe query builders (`Select()`, `Insert()`, `Update()`, `Delete()`)
|
||||||
|
|
||||||
|
**Example header:**
|
||||||
|
```go
|
||||||
|
// Code generated by code.patial.tech/go/pgm/cmd v1.2.3 on 2025-01-27 15:04:05 DO NOT EDIT.
|
||||||
|
```
|
||||||
|
|
||||||
|
The version in generated files helps track which version of the CLI tool was used, making it easier to identify when regeneration is needed after upgrades.
|
||||||
|
|
||||||
|
## API Documentation
|
||||||
|
|
||||||
|
### Connection Pool
|
||||||
|
|
||||||
|
#### InitPool
|
||||||
|
|
||||||
|
Initialize the connection pool (must be called once at startup):
|
||||||
|
|
||||||
|
```go
|
||||||
|
pgm.InitPool(pgm.Config{
|
||||||
|
ConnString: "postgres://...",
|
||||||
|
MaxConns: 25,
|
||||||
|
MinConns: 5,
|
||||||
|
MaxConnLifetime: time.Hour,
|
||||||
|
MaxConnIdleTime: time.Minute * 30,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
**Configuration Validation:**
|
||||||
|
- MinConns cannot be greater than MaxConns
|
||||||
|
- Connection counts cannot be negative
|
||||||
|
- Connection string is required
|
||||||
|
|
||||||
|
#### ClosePool
|
||||||
|
|
||||||
|
Close the connection pool gracefully (call during application shutdown):
|
||||||
|
|
||||||
|
```go
|
||||||
|
func main() {
|
||||||
|
pgm.InitPool(pgm.Config{
|
||||||
|
ConnString: "postgres://...",
|
||||||
|
})
|
||||||
|
defer pgm.ClosePool() // Ensures proper cleanup
|
||||||
|
|
||||||
|
// Your application code
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### GetPool
|
||||||
|
|
||||||
|
Get the underlying pgx pool:
|
||||||
|
|
||||||
|
```go
|
||||||
|
pool := pgm.GetPool()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Query Conditions
|
||||||
|
|
||||||
|
Available condition methods on fields:
|
||||||
|
|
||||||
|
- `Eq(value)` - Equal to
|
||||||
|
- `NotEq(value)` - Not equal to
|
||||||
|
- `Gt(value)` - Greater than
|
||||||
|
- `Gte(value)` - Greater than or equal to
|
||||||
|
- `Lt(value)` - Less than
|
||||||
|
- `Lte(value)` - Less than or equal to
|
||||||
|
- `Like(pattern)` - LIKE pattern match
|
||||||
|
- `ILike(pattern)` - Case-insensitive LIKE
|
||||||
|
- `In(values...)` - IN list
|
||||||
|
- `NotIn(values...)` - NOT IN list
|
||||||
|
- `IsNull()` - IS NULL
|
||||||
|
- `IsNotNull()` - IS NOT NULL
|
||||||
|
- `Between(start, end)` - BETWEEN range
|
||||||
|
|
||||||
|
### Field Methods
|
||||||
|
|
||||||
|
- `Asc()` - Sort ascending
|
||||||
|
- `Desc()` - Sort descending
|
||||||
|
- `Name()` - Get column name
|
||||||
|
- `String()` - Get fully qualified name (table.column)
|
||||||
|
|
||||||
|
### Utilities
|
||||||
|
|
||||||
|
- `pgm.PgTime(t time.Time)` - Convert Go time to PostgreSQL timestamptz
|
||||||
|
- `pgm.PgTimeNow()` - Current time as PostgreSQL timestamptz
|
||||||
|
- `pgm.IsNotFound(err)` - Check if error is "no rows found"
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Use transactions for related operations** - Ensure data consistency
|
||||||
|
2. **Define schema in SQL** - Use migration tools like dbmate for schema management
|
||||||
|
3. **Regenerate after schema changes** - Run the CLI tool after any schema modifications
|
||||||
|
4. **Use your own models** - Don't let the database dictate your domain models
|
||||||
|
5. **Handle pgx.ErrNoRows** - Use `pgm.IsNotFound(err)` for cleaner error checking
|
||||||
|
6. **Always use context with timeouts** - Prevent queries from running indefinitely
|
||||||
|
7. **Validate UPDATE queries** - Ensure Set() is called before Exec()
|
||||||
|
8. **Be careful with DELETE** - Always use Where() unless you want to delete all rows
|
||||||
|
|
||||||
|
## Important Safety Notes
|
||||||
|
|
||||||
|
### ⚠️ DELETE Operations
|
||||||
|
|
||||||
|
DELETE without WHERE clause will delete ALL rows in the table:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ DANGEROUS - Deletes ALL rows!
|
||||||
|
users.User.Delete().Exec(ctx)
|
||||||
|
|
||||||
|
// ✅ SAFE - Deletes specific rows
|
||||||
|
users.User.Delete().Where(users.ID.Eq("user-id")).Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### ⚠️ UPDATE Operations
|
||||||
|
|
||||||
|
UPDATE requires at least one Set() call:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ ERROR - No columns to update
|
||||||
|
users.User.Update().Where(users.ID.Eq(1)).Exec(ctx)
|
||||||
|
// Returns: "update query has no columns to update"
|
||||||
|
|
||||||
|
// ✅ CORRECT
|
||||||
|
users.User.Update().
|
||||||
|
Set(users.Name, "New Name").
|
||||||
|
Where(users.ID.Eq(1)).
|
||||||
|
Exec(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### ⚠️ Query Timeouts
|
||||||
|
|
||||||
|
Always use context with timeout to prevent hanging queries:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ✅ RECOMMENDED
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := users.User.Select(users.Email).
|
||||||
|
Where(users.ID.Eq("some-id")).
|
||||||
|
First(ctx, &email)
|
||||||
|
```
|
||||||
|
|
||||||
|
### ⚠️ Connection String Security
|
||||||
|
|
||||||
|
Never log or expose database connection strings as they contain credentials. The library does not sanitize connection strings in error messages.
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
**pgm** is designed for performance:
|
||||||
|
|
||||||
|
- Zero reflection overhead
|
||||||
|
- Efficient string building with sync.Pool
|
||||||
|
- Leverages pgx's high-performance connection pooling
|
||||||
|
- Minimal allocations in query building
|
||||||
|
- Direct scanning into your types
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- Go 1.20 or higher
|
||||||
|
- PostgreSQL 12 or higher
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||||
|
|
||||||
|
## Author
|
||||||
|
|
||||||
|
|
||||||
|
**Ankit Patial** - [Patial Tech](https://code.patial.tech)
|
||||||
|
|
||||||
|
## Acknowledgments
|
||||||
|
|
||||||
|
Built on top of the excellent [jackc/pgx](https://github.com/jackc/pgx) library.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Made with ❤️ for the Go community**
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ import (
|
|||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// version can be set at build time using:
|
||||||
|
// go build -ldflags "-X main.version=v1.2.3" ./cmd
|
||||||
|
var version = "dev"
|
||||||
|
|
||||||
func generate(scheamPath, outDir string) error {
|
func generate(scheamPath, outDir string) error {
|
||||||
// read schame.sql
|
// read schame.sql
|
||||||
f, err := os.ReadFile(scheamPath)
|
f, err := os.ReadFile(scheamPath)
|
||||||
@@ -29,14 +33,16 @@ func generate(scheamPath, outDir string) error {
|
|||||||
|
|
||||||
// Output dir, create if not exists.
|
// Output dir, create if not exists.
|
||||||
if _, err := os.Stat(outDir); os.IsNotExist(err) {
|
if _, err := os.Stat(outDir); os.IsNotExist(err) {
|
||||||
if err := os.MkdirAll(outDir, 0740); err != nil {
|
if err := os.MkdirAll(outDir, 0755); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// schema.go will hold all tables info
|
// schema.go will hold all tables info
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.\n\n")
|
sb.WriteString(
|
||||||
|
fmt.Sprintf("// Code generated by code.patial.tech/go/pgm/cmd %s DO NOT EDIT.\n\n", GetVersionString()),
|
||||||
|
)
|
||||||
sb.WriteString(fmt.Sprintf("package %s \n", filepath.Base(outDir)))
|
sb.WriteString(fmt.Sprintf("package %s \n", filepath.Base(outDir)))
|
||||||
sb.WriteString(`
|
sb.WriteString(`
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
@@ -70,14 +76,24 @@ func generate(scheamPath, outDir string) error {
|
|||||||
sb.WriteString("}\n")
|
sb.WriteString("}\n")
|
||||||
}
|
}
|
||||||
modalDir = strings.ToLower(name)
|
modalDir = strings.ToLower(name)
|
||||||
os.Mkdir(filepath.Join(outDir, modalDir), 0740)
|
os.Mkdir(filepath.Join(outDir, modalDir), 0755)
|
||||||
|
|
||||||
if err = writeColFile(t.Table, t.Columns, filepath.Join(outDir, modalDir), caser); err != nil {
|
if err = writeColFile(t.Table, t.Columns, filepath.Join(outDir, modalDir), caser); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(")")
|
sb.WriteString(")")
|
||||||
|
|
||||||
|
sb.WriteString(`
|
||||||
|
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||||
|
t := pgm.Table{
|
||||||
|
Name: tblName,
|
||||||
|
DerivedTable: fromQry,
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}`)
|
||||||
|
|
||||||
// Format code before saving
|
// Format code before saving
|
||||||
code, err := formatGoCode(sb.String())
|
code, err := formatGoCode(sb.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -85,16 +101,20 @@ func generate(scheamPath, outDir string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Save file to disk
|
// Save file to disk
|
||||||
os.WriteFile(filepath.Join(outDir, "schema.go"), code, 0640)
|
os.WriteFile(filepath.Join(outDir, "schema.go"), code, 0644)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Caser) error {
|
func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Caser) error {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.\n\n")
|
sb.WriteString(
|
||||||
|
fmt.Sprintf("// Code generated by code.patial.tech/go/pgm/cmd %s DO NOT EDIT.\n\n", GetVersionString()),
|
||||||
|
)
|
||||||
sb.WriteString(fmt.Sprintf("package %s\n\n", filepath.Base(outDir)))
|
sb.WriteString(fmt.Sprintf("package %s\n\n", filepath.Base(outDir)))
|
||||||
sb.WriteString(fmt.Sprintf("import %q\n\n", "code.patial.tech/go/pgm"))
|
sb.WriteString(fmt.Sprintf("import %q\n\n", "code.patial.tech/go/pgm"))
|
||||||
sb.WriteString("const (")
|
sb.WriteString("const (")
|
||||||
|
sb.WriteString("\n // All fields in table " + tblName)
|
||||||
|
sb.WriteString(fmt.Sprintf("\n All pgm.Field = %q", tblName+".*"))
|
||||||
var name string
|
var name string
|
||||||
for _, c := range cols {
|
for _, c := range cols {
|
||||||
name = strings.ReplaceAll(c.Name, "_", " ")
|
name = strings.ReplaceAll(c.Name, "_", " ")
|
||||||
@@ -117,7 +137,7 @@ func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Cas
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Save file to disk.
|
// Save file to disk.
|
||||||
return os.WriteFile(filepath.Join(outDir, tblName+".go"), code, 0640)
|
return os.WriteFile(filepath.Join(outDir, tblName+".go"), code, 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
// pluralToSingular converts plural table names to singular forms
|
// pluralToSingular converts plural table names to singular forms
|
||||||
|
|||||||
21
cmd/main.go
21
cmd/main.go
@@ -9,29 +9,42 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageTxt = `Please provide output director and input schema.
|
var (
|
||||||
|
showVersion bool
|
||||||
|
)
|
||||||
|
|
||||||
|
const usageTxt = `Please provide output directory and input schema.
|
||||||
Example:
|
Example:
|
||||||
pgm/cmd -o ./db ./db/schema.sql
|
pgm -o ./db ./schema.sql
|
||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var outDir string
|
var outDir string
|
||||||
flag.StringVar(&outDir, "o", "", "-o as output directory path")
|
flag.StringVar(&outDir, "o", "", "-o as output directory path")
|
||||||
|
flag.BoolVar(&showVersion, "version", false, "show version information")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
// Handle version flag
|
||||||
|
if showVersion {
|
||||||
|
fmt.Printf("pgm %s\n", GetVersionString())
|
||||||
|
fmt.Println("PostgreSQL Query Mapper - Schema code generator")
|
||||||
|
fmt.Println("https://code.patial.tech/go/pgm")
|
||||||
|
return
|
||||||
|
}
|
||||||
if len(os.Args) < 4 {
|
if len(os.Args) < 4 {
|
||||||
fmt.Print(usageTxt)
|
fmt.Print(usageTxt)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if outDir == "" {
|
if outDir == "" {
|
||||||
println("missing, -o output directory path")
|
fmt.Fprintln(os.Stderr, "Error: missing output directory path (-o flag required)")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := generate(os.Args[3], outDir); err != nil {
|
if err := generate(os.Args[3], outDir); err != nil {
|
||||||
println(err.Error())
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
15
cmd/parse.go
15
cmd/parse.go
@@ -1,5 +1,20 @@
|
|||||||
// Patial Tech.
|
// Patial Tech.
|
||||||
// Author, Ankit Patial
|
// Author, Ankit Patial
|
||||||
|
//
|
||||||
|
// SQL Parser Limitations:
|
||||||
|
// This is a simple regex-based SQL parser with the following known limitations:
|
||||||
|
// - No support for multi-line comments /* */
|
||||||
|
// - May struggle with complex data types (e.g., arrays, JSON, JSONB)
|
||||||
|
// - No handling of quoted identifiers with special characters
|
||||||
|
// - Advanced features like PARTITION BY, INHERITS not supported
|
||||||
|
// - Does not handle all PostgreSQL-specific syntax
|
||||||
|
// - Constraints like CHECK, EXCLUDE not parsed
|
||||||
|
//
|
||||||
|
// For complex schemas, consider:
|
||||||
|
// 1. Simplifying your schema for generation
|
||||||
|
// 2. Using multiple simple CREATE TABLE statements
|
||||||
|
// 3. Contributing a more robust parser implementation
|
||||||
|
// 4. Using a proper SQL parser library
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
|||||||
67
cmd/version.go
Normal file
67
cmd/version.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// Patial Tech.
|
||||||
|
// Author, Ankit Patial
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Version returns the version of the pgm CLI tool.
|
||||||
|
// It tries to detect the version in the following order:
|
||||||
|
// 1. Build-time ldflags (set via: go build -ldflags "-X main.version=v1.2.3")
|
||||||
|
// 2. VCS information from build metadata (git tag/commit)
|
||||||
|
// 3. Falls back to "dev" if no version information is available
|
||||||
|
func Version() string {
|
||||||
|
// If version was set at build time via ldflags
|
||||||
|
if version != "" && version != "dev" {
|
||||||
|
return version
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get version from build info (Go 1.18+)
|
||||||
|
if info, ok := debug.ReadBuildInfo(); ok {
|
||||||
|
// Check for version in main module
|
||||||
|
if info.Main.Version != "" && info.Main.Version != "(devel)" {
|
||||||
|
return info.Main.Version
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to extract from VCS information
|
||||||
|
var revision, modified string
|
||||||
|
for _, setting := range info.Settings {
|
||||||
|
switch setting.Key {
|
||||||
|
case "vcs.revision":
|
||||||
|
revision = setting.Value
|
||||||
|
case "vcs.modified":
|
||||||
|
modified = setting.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have a git revision
|
||||||
|
if revision != "" {
|
||||||
|
// Shorten commit hash to 7 characters
|
||||||
|
if len(revision) > 7 {
|
||||||
|
revision = revision[:7]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add -dirty suffix if modified
|
||||||
|
if modified == "true" {
|
||||||
|
return "dev-" + revision + "-dirty"
|
||||||
|
}
|
||||||
|
return "dev-" + revision
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to dev
|
||||||
|
return "dev"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVersionString returns a formatted version string for display
|
||||||
|
func GetVersionString() string {
|
||||||
|
v := Version()
|
||||||
|
|
||||||
|
// Clean up version string for display
|
||||||
|
v = strings.TrimPrefix(v, "v")
|
||||||
|
|
||||||
|
return "v" + v
|
||||||
|
}
|
||||||
8
go.mod
8
go.mod
@@ -3,14 +3,14 @@ module code.patial.tech/go/pgm
|
|||||||
go 1.24.5
|
go 1.24.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/jackc/pgx/v5 v5.7.5
|
github.com/jackc/pgx/v5 v5.7.6
|
||||||
golang.org/x/text v0.27.0
|
golang.org/x/text v0.31.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
golang.org/x/crypto v0.40.0 // indirect
|
golang.org/x/crypto v0.45.0 // indirect
|
||||||
golang.org/x/sync v0.16.0 // indirect
|
golang.org/x/sync v0.18.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -7,6 +7,8 @@ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7Ulw
|
|||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs=
|
github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs=
|
||||||
github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||||
|
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
|
||||||
|
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
@@ -18,10 +20,16 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs
|
|||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||||
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||||
|
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||||
|
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
|||||||
308
pgm.go
308
pgm.go
@@ -1,187 +1,197 @@
|
|||||||
// Patial Tech.
|
// pgm
|
||||||
// Author, Ankit Patial
|
//
|
||||||
|
// A simple PG string query builder
|
||||||
|
//
|
||||||
|
// Author: Ankit Patial
|
||||||
|
|
||||||
package pgm
|
package pgm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Table in database
|
var (
|
||||||
type Table struct {
|
poolPGX atomic.Pointer[pgxpool.Pool]
|
||||||
Name string
|
ErrConnStringMissing = errors.New("connection string is empty")
|
||||||
PK []string
|
)
|
||||||
FieldCount uint16
|
|
||||||
debug bool
|
// Common errors returned by pgm operations
|
||||||
}
|
var (
|
||||||
|
ErrInitTX = errors.New("failed to init db.tx")
|
||||||
// Debug when set true will print generated query string in stdout
|
ErrCommitTX = errors.New("failed to commit db.tx")
|
||||||
func (t Table) Debug() Clause {
|
ErrNoRows = errors.New("no data found")
|
||||||
t.debug = true
|
)
|
||||||
return t
|
|
||||||
|
// Config holds the configuration for initializing the connection pool.
|
||||||
|
// All fields except ConnString are optional and will use pgx defaults if not set.
|
||||||
|
type Config struct {
|
||||||
|
ConnString string
|
||||||
|
MaxConns int32
|
||||||
|
MinConns int32
|
||||||
|
MaxConnLifetime time.Duration
|
||||||
|
MaxConnIdleTime time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitPool initializes the connection pool with the provided configuration.
|
||||||
|
// It validates the configuration and panics if invalid.
|
||||||
|
// This function should be called once at application startup.
|
||||||
//
|
//
|
||||||
// Field ==>
|
// Example:
|
||||||
//
|
//
|
||||||
|
// pgm.InitPool(pgm.Config{
|
||||||
|
// ConnString: "postgres://user:pass@localhost/dbname",
|
||||||
|
// MaxConns: 100,
|
||||||
|
// MinConns: 5,
|
||||||
|
// })
|
||||||
|
func InitPool(conf Config) {
|
||||||
|
if conf.ConnString == "" {
|
||||||
|
panic(ErrConnStringMissing)
|
||||||
|
}
|
||||||
|
|
||||||
// Field related to a table
|
// Validate configuration
|
||||||
type Field string
|
if conf.MaxConns > 0 && conf.MinConns > 0 && conf.MinConns > conf.MaxConns {
|
||||||
|
panic(fmt.Errorf("MinConns (%d) cannot be greater than MaxConns (%d)", conf.MinConns, conf.MaxConns))
|
||||||
|
}
|
||||||
|
|
||||||
func (f Field) Name() string {
|
if conf.MaxConns < 0 || conf.MinConns < 0 {
|
||||||
return strings.Split(string(f), ".")[1]
|
panic(errors.New("connection pool configuration cannot have negative values"))
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := pgxpool.ParseConfig(conf.ConnString)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.MaxConns > 0 {
|
||||||
|
cfg.MaxConns = conf.MaxConns // 100
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.MinConns > 0 {
|
||||||
|
cfg.MinConns = conf.MinConns // 5
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.MaxConnLifetime > 0 {
|
||||||
|
cfg.MaxConnLifetime = conf.MaxConnLifetime // time.Minute * 10
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.MaxConnIdleTime > 0 {
|
||||||
|
cfg.MaxConnIdleTime = conf.MaxConnIdleTime // time.Minute * 5
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := pgxpool.NewWithConfig(context.Background(), cfg)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = p.Ping(context.Background()); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
poolPGX.Store(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f Field) String() string {
|
// GetPool returns the initialized connection pool instance.
|
||||||
return string(f)
|
// It panics with a descriptive message if InitPool() has not been called.
|
||||||
|
// This is a fail-fast approach to catch programming errors early.
|
||||||
|
func GetPool() *pgxpool.Pool {
|
||||||
|
p := poolPGX.Load()
|
||||||
|
if p == nil {
|
||||||
|
panic("pgm: connection pool not initialized, call InitPool() first")
|
||||||
|
}
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count fn wrapping of field
|
// ClosePool closes the connection pool gracefully.
|
||||||
func (f Field) Count() Field {
|
// Should be called during application shutdown.
|
||||||
return Field("COUNT(" + f.String() + ")")
|
func ClosePool() {
|
||||||
}
|
if p := poolPGX.Load(); p != nil {
|
||||||
|
p.Close()
|
||||||
// StringEscape will return a empty string for null value
|
poolPGX.Store(nil)
|
||||||
func (f Field) StringEscape() Field {
|
}
|
||||||
return Field("COALESCE(" + f.String() + ", '')")
|
|
||||||
}
|
|
||||||
|
|
||||||
// NumberEscape will return a zero string for null value
|
|
||||||
func (f Field) NumberEscape() Field {
|
|
||||||
return Field("COALESCE(" + f.String() + ", 0)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// BooleanEscape will return a false for null value
|
|
||||||
func (f Field) BooleanEscape() Field {
|
|
||||||
return Field("COALESCE(" + f.String() + ", FALSE)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Avg fn wrapping of field
|
|
||||||
func (f Field) Avg() Field {
|
|
||||||
return Field("AVG(" + f.String() + ")")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Sum() Field {
|
|
||||||
return Field("SUM(" + f.String() + ")")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Max() Field {
|
|
||||||
return Field("MAX(" + f.String() + ")")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Min() Field {
|
|
||||||
return Field("Min(" + f.String() + ")")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Lower() Field {
|
|
||||||
return Field("LOWER(" + f.String() + ")")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Upper() Field {
|
|
||||||
return Field("UPPER(" + f.String() + ")")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Trim() Field {
|
|
||||||
return Field("TRIM(" + f.String() + ")")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) IsNull() Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, op: " IS NULL", len: len(col) + 8}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) IsNotNull() Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, op: " IS NOT NULL", len: len(col) + 12}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Eq is equal
|
|
||||||
func (f Field) Eq(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " = $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EqualFold will use LOWER(column_name) = LOWER(val) for comparision
|
|
||||||
func (f Field) EqFold(val string) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " = LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) NEq(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " != $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Gt(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " > $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Lt(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " < $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Gte(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " >= $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Lte(val any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " <= $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) Like(val string) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " LIKE $", len: len(f.String()) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) LikeFold(val string) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " LIKE LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ILIKE is case-insensitive
|
|
||||||
func (f Field) ILike(val string) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " ILIKE $", len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) NotIn(val ...any) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: val, op: " NOT IN($", action: CondActionNeedToClose, len: len(col) + 5}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
|
|
||||||
col := f.String()
|
|
||||||
return &Cond{Field: col, Val: qry, op: " NOT IN($)", action: CondActionSubQuery}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BeginTx begins a new database transaction from the connection pool.
|
||||||
|
// Returns an error if the transaction cannot be started.
|
||||||
|
// Remember to commit or rollback the transaction when done.
|
||||||
//
|
//
|
||||||
// Helper func ==>
|
// Example:
|
||||||
//
|
//
|
||||||
|
// tx, err := pgm.BeginTx(ctx)
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// defer tx.Rollback(ctx) // rollback on error
|
||||||
|
//
|
||||||
|
// // ... do work ...
|
||||||
|
//
|
||||||
|
// return tx.Commit(ctx)
|
||||||
|
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
||||||
|
tx, err := poolPGX.Load().Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to begin transaction", "error", err)
|
||||||
|
return nil, fmt.Errorf("failed to open db tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// PgTime as in UTC
|
return tx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotFound checks if an error is a "no rows" error from pgx.
|
||||||
|
// Returns true if the error indicates no rows were found in a query result.
|
||||||
|
func IsNotFound(err error) bool {
|
||||||
|
return errors.Is(err, pgx.ErrNoRows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PgTime converts a Go time.Time to PostgreSQL timestamptz type.
|
||||||
|
// The time is stored as-is (preserves timezone information).
|
||||||
func PgTime(t time.Time) pgtype.Timestamptz {
|
func PgTime(t time.Time) pgtype.Timestamptz {
|
||||||
return pgtype.Timestamptz{Time: t, Valid: true}
|
return pgtype.Timestamptz{Time: t, Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PgTimeNow returns the current time as PostgreSQL timestamptz type.
|
||||||
func PgTimeNow() pgtype.Timestamptz {
|
func PgTimeNow() pgtype.Timestamptz {
|
||||||
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
return pgtype.Timestamptz{Time: time.Now(), Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConcatWs(sep string, fields ...Field) Field {
|
// TsAndQuery converts a text search query to use AND operator between terms.
|
||||||
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
|
// Example: "hello world" becomes "hello & world"
|
||||||
|
func TsAndQuery(q string) string {
|
||||||
|
return strings.Join(strings.Fields(q), " & ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func StringAgg(exp, sep string) Field {
|
// TsPrefixAndQuery converts a text search query to use AND operator with prefix matching.
|
||||||
return Field("string_agg(" + exp + ",'" + sep + "')")
|
// Example: "hello world" becomes "hello:* & world:*"
|
||||||
|
func TsPrefixAndQuery(q string) string {
|
||||||
|
return strings.Join(fieldsWithSufix(q, ":*"), " & ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func StringAggCast(exp, sep string) Field {
|
// TsOrQuery converts a text search query to use OR operator between terms.
|
||||||
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
|
// Example: "hello world" becomes "hello | world"
|
||||||
|
func TsOrQuery(q string) string {
|
||||||
|
return strings.Join(strings.Fields(q), " | ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TsPrefixOrQuery converts a text search query to use OR operator with prefix matching.
|
||||||
|
// Example: "hello world" becomes "hello:* | world:*"
|
||||||
|
func TsPrefixOrQuery(q string) string {
|
||||||
|
return strings.Join(fieldsWithSufix(q, ":*"), " | ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func fieldsWithSufix(v, sufix string) []string {
|
||||||
|
fields := strings.Fields(v)
|
||||||
|
prefixed := make([]string, len(fields))
|
||||||
|
for i, f := range fields {
|
||||||
|
prefixed[i] = f + sufix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefixed
|
||||||
}
|
}
|
||||||
|
|||||||
340
pgm_field.go
Normal file
340
pgm_field.go
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
package pgm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Field related to a table
|
||||||
|
type Field string
|
||||||
|
|
||||||
|
var (
|
||||||
|
// sqlIdentifierRegex validates SQL identifiers (alphanumeric and underscore only)
|
||||||
|
sqlIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||||
|
|
||||||
|
// validDateTruncLevels contains all allowed DATE_TRUNC precision levels
|
||||||
|
validDateTruncLevels = map[string]bool{
|
||||||
|
"microseconds": true,
|
||||||
|
"milliseconds": true,
|
||||||
|
"second": true,
|
||||||
|
"minute": true,
|
||||||
|
"hour": true,
|
||||||
|
"day": true,
|
||||||
|
"week": true,
|
||||||
|
"month": true,
|
||||||
|
"quarter": true,
|
||||||
|
"year": true,
|
||||||
|
"decade": true,
|
||||||
|
"century": true,
|
||||||
|
"millennium": true,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// validateSQLIdentifier checks if a string is a valid SQL identifier
|
||||||
|
func validateSQLIdentifier(s string) error {
|
||||||
|
if !sqlIdentifierRegex.MatchString(s) {
|
||||||
|
return fmt.Errorf("invalid SQL identifier: %q (must be alphanumeric and underscore only)", s)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// escapeSQLString escapes single quotes in a string for SQL
|
||||||
|
func escapeSQLString(s string) string {
|
||||||
|
return strings.ReplaceAll(s, "'", "''")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Name() string {
|
||||||
|
s := string(f)
|
||||||
|
idx := strings.LastIndexByte(s, '.')
|
||||||
|
if idx == -1 {
|
||||||
|
return s // Return as-is if no dot
|
||||||
|
}
|
||||||
|
return s[idx+1:] // Return part after last dot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) String() string {
|
||||||
|
return string(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count function wrapped field
|
||||||
|
func (f Field) Count() Field {
|
||||||
|
return Field("COUNT(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConcatWs creates a CONCAT_WS SQL function.
|
||||||
|
// SECURITY: The sep parameter should only be a constant string, not user input.
|
||||||
|
// Single quotes in sep will be escaped automatically.
|
||||||
|
func ConcatWs(sep string, fields ...Field) Field {
|
||||||
|
escapedSep := escapeSQLString(sep)
|
||||||
|
return Field("concat_ws('" + escapedSep + "'," + joinFileds(fields) + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StringAgg creates a STRING_AGG SQL function.
|
||||||
|
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||||
|
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||||
|
func StringAgg(field Field, sep string) Field {
|
||||||
|
escapedSep := escapeSQLString(sep)
|
||||||
|
return Field("string_agg(" + field.String() + ",'" + escapedSep + "')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
|
||||||
|
// SECURITY: The field parameter provides type safety by accepting only Field type.
|
||||||
|
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||||
|
func StringAggCast(field Field, sep string) Field {
|
||||||
|
escapedSep := escapeSQLString(sep)
|
||||||
|
return Field("string_agg(cast(" + field.String() + " as varchar),'" + escapedSep + "')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StringEscape will wrap field with:
|
||||||
|
//
|
||||||
|
// COALESCE(field, ”)
|
||||||
|
func (f Field) StringEscape() Field {
|
||||||
|
return Field("COALESCE(" + f.String() + ", '')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumberEscape will wrap field with:
|
||||||
|
//
|
||||||
|
// COALESCE(field, 0)
|
||||||
|
func (f Field) NumberEscape() Field {
|
||||||
|
return Field("COALESCE(" + f.String() + ", 0)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BooleanEscape will wrap field with:
|
||||||
|
//
|
||||||
|
// COALESCE(field, FALSE)
|
||||||
|
func (f Field) BooleanEscape() Field {
|
||||||
|
return Field("COALESCE(" + f.String() + ", FALSE)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Avg function wrapped field
|
||||||
|
func (f Field) Avg() Field {
|
||||||
|
return Field("AVG(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sum function wrapped field
|
||||||
|
func (f Field) Sum() Field {
|
||||||
|
return Field("SUM(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Max function wrapped field
|
||||||
|
func (f Field) Max() Field {
|
||||||
|
return Field("MAX(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Min function wrapped field
|
||||||
|
func (f Field) Min() Field {
|
||||||
|
return Field("Min(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lower function wrapped field
|
||||||
|
func (f Field) Lower() Field {
|
||||||
|
return Field("LOWER(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upper function wrapped field
|
||||||
|
func (f Field) Upper() Field {
|
||||||
|
return Field("UPPER(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trim function wrapped field
|
||||||
|
func (f Field) Trim() Field {
|
||||||
|
return Field("TRIM(" + f.String() + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Asc suffixed field, supposed to be used with order by
|
||||||
|
func (f Field) Asc() Field {
|
||||||
|
return Field(f.String() + " ASC")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Desc suffixed field, supposed to be used with order by
|
||||||
|
func (f Field) Desc() Field {
|
||||||
|
return Field(f.String() + " DESC")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) RowNumber(as string, extraOrderBy ...Field) Field {
|
||||||
|
return rowNumber(&f, nil, true, as, extraOrderBy...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) RowNumberDesc(as string, extraOrderBy ...Field) Field {
|
||||||
|
return rowNumber(&f, nil, false, as, extraOrderBy...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowNumberPartionBy in ascending order
|
||||||
|
func (f Field) RowNumberPartionBy(partition Field, as string, extraOrderBy ...Field) Field {
|
||||||
|
return rowNumber(&f, &partition, true, as, extraOrderBy...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) RowNumberDescPartionBy(partition Field, as string, extraOrderBy ...Field) Field {
|
||||||
|
return rowNumber(&f, &partition, false, as, extraOrderBy...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rowNumber(f, partition *Field, isAsc bool, as string, extraOrderBy ...Field) Field {
|
||||||
|
// Validate as parameter is a valid SQL identifier
|
||||||
|
if as != "" {
|
||||||
|
if err := validateSQLIdentifier(as); err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid AS alias in rowNumber: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var orderBy string
|
||||||
|
if isAsc {
|
||||||
|
orderBy = " ASC"
|
||||||
|
} else {
|
||||||
|
orderBy = " DESC"
|
||||||
|
}
|
||||||
|
|
||||||
|
if as == "" {
|
||||||
|
as = "row_number"
|
||||||
|
}
|
||||||
|
|
||||||
|
col := f.String()
|
||||||
|
|
||||||
|
// Build ORDER BY clause with primary field and extra fields
|
||||||
|
sb := getSB()
|
||||||
|
defer putSB(sb)
|
||||||
|
|
||||||
|
sb.WriteString(col)
|
||||||
|
sb.WriteString(orderBy)
|
||||||
|
|
||||||
|
// Add extra ORDER BY fields
|
||||||
|
for _, extra := range extraOrderBy {
|
||||||
|
sb.WriteString(", ")
|
||||||
|
sb.WriteString(extra.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
orderByClause := sb.String()
|
||||||
|
|
||||||
|
if partition != nil {
|
||||||
|
return Field("ROW_NUMBER() OVER (PARTITION BY " + partition.String() + " ORDER BY " + orderByClause + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Field("ROW_NUMBER() OVER (ORDER BY " + orderByClause + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) IsNull() Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, op: " IS NULL", len: len(col) + 8}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) IsNotNull() Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, op: " IS NOT NULL", len: len(col) + 12}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DateTrunc will truncate date or timestamp to specified level of precision
|
||||||
|
//
|
||||||
|
// Level values:
|
||||||
|
// - microseconds, milliseconds, second, minute, hour
|
||||||
|
// - day, week (Monday start), month, quarter, year
|
||||||
|
// - decade, century, millennium
|
||||||
|
func (f Field) DateTrunc(level, as string) Field {
|
||||||
|
// Validate level parameter against allowed values
|
||||||
|
if !validDateTruncLevels[strings.ToLower(level)] {
|
||||||
|
panic(fmt.Sprintf("invalid DATE_TRUNC level: %q (allowed: microseconds, milliseconds, second, minute, hour, day, week, month, quarter, year, decade, century, millennium)", level))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate as parameter is a valid SQL identifier
|
||||||
|
if err := validateSQLIdentifier(as); err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid AS alias in DateTrunc: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return Field("DATE_TRUNC('" + strings.ToLower(level) + "', " + f.String() + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) TsRank(fieldName, as string) Field {
|
||||||
|
// Validate as parameter is a valid SQL identifier
|
||||||
|
if err := validateSQLIdentifier(as); err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid AS alias in TsRank: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EqualFold will use LOWER(column_name) = LOWER(val) for comparision
|
||||||
|
func (f Field) EqFold(val string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " = LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Eq is equal
|
||||||
|
func (f Field) Eq(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " = $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) NotEq(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " != $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Gt(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " > $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Lt(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " < $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Gte(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " >= $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Lte(val any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " <= $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Like(val string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " LIKE $", len: len(f.String()) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) LikeFold(val string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: "LOWER(" + col + ")", Val: val, op: " LIKE LOWER($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ILIKE is case-insensitive
|
||||||
|
func (f Field) ILike(val string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " ILIKE $", len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) Any(val ...any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " = ANY($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) NotAny(val ...any) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: val, op: " != ANY($", action: CondActionNeedToClose, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotInSubQuery using ANY
|
||||||
|
func (f Field) NotInSubQuery(qry WhereClause) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, Val: qry, op: " != ANY($)", action: CondActionSubQuery}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Field) TsQuery(as string) Conditioner {
|
||||||
|
col := f.String()
|
||||||
|
return &Cond{Field: col, op: " @@ " + as, len: len(col) + 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinFileds(fields []Field) string {
|
||||||
|
sb := getSB()
|
||||||
|
defer putSB(sb)
|
||||||
|
for i, f := range fields {
|
||||||
|
if i == 0 {
|
||||||
|
sb.WriteString(f.String())
|
||||||
|
} else {
|
||||||
|
sb.WriteString(", ")
|
||||||
|
sb.WriteString(f.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
382
pgm_field_test.go
Normal file
382
pgm_field_test.go
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
package pgm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestValidateSQLIdentifier tests SQL identifier validation
|
||||||
|
func TestValidateSQLIdentifier(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"valid simple", "column_name", false},
|
||||||
|
{"valid with numbers", "column123", false},
|
||||||
|
{"valid underscore", "_private", false},
|
||||||
|
{"valid mixed", "my_Column_123", false},
|
||||||
|
{"invalid with space", "column name", true},
|
||||||
|
{"invalid with dash", "column-name", true},
|
||||||
|
{"invalid with dot", "table.column", true},
|
||||||
|
{"invalid with quote", "column'name", true},
|
||||||
|
{"invalid with semicolon", "column;DROP", true},
|
||||||
|
{"invalid starts with number", "123column", true},
|
||||||
|
{"invalid SQL keyword injection", "col); DROP TABLE users; --", true},
|
||||||
|
{"empty string", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateSQLIdentifier(tt.input)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validateSQLIdentifier(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEscapeSQLString tests SQL string escaping
|
||||||
|
func TestEscapeSQLString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"no quotes", "hello", "hello"},
|
||||||
|
{"single quote", "hello'world", "hello''world"},
|
||||||
|
{"multiple quotes", "it's a 'test'", "it''s a ''test''"},
|
||||||
|
{"only quote", "'", "''"},
|
||||||
|
{"empty string", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := escapeSQLString(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("escapeSQLString(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcatWsSQLInjectionPrevention tests that ConcatWs escapes quotes
|
||||||
|
func TestConcatWsSQLInjectionPrevention(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sep string
|
||||||
|
fields []Field
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "safe separator",
|
||||||
|
sep: ", ",
|
||||||
|
fields: []Field{"col1", "col2"},
|
||||||
|
contains: "concat_ws(', ',col1, col2)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "escaped quotes",
|
||||||
|
sep: "', (SELECT password FROM users), '",
|
||||||
|
fields: []Field{"col1"},
|
||||||
|
contains: "concat_ws(''', (SELECT password FROM users), ''',col1)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single quote",
|
||||||
|
sep: "'",
|
||||||
|
fields: []Field{"col1"},
|
||||||
|
contains: "concat_ws('''',col1)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ConcatWs(tt.sep, tt.fields...)
|
||||||
|
if !strings.Contains(string(result), tt.contains) {
|
||||||
|
t.Errorf("ConcatWs(%q, %v) = %q, should contain %q", tt.sep, tt.fields, result, tt.contains)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStringAggSQLInjectionPrevention tests that StringAgg escapes quotes
|
||||||
|
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
field Field
|
||||||
|
sep string
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "safe parameters",
|
||||||
|
field: Field("column_name"),
|
||||||
|
sep: ", ",
|
||||||
|
contains: "string_agg(column_name,', ')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "escaped quotes in separator",
|
||||||
|
sep: "'; DROP TABLE users; --",
|
||||||
|
field: Field("col"),
|
||||||
|
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := StringAgg(tt.field, tt.sep)
|
||||||
|
if !strings.Contains(string(result), tt.contains) {
|
||||||
|
t.Errorf("StringAgg(%v, %q) = %q, should contain %q", tt.field, tt.sep, result, tt.contains)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
||||||
|
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
||||||
|
result := StringAggCast(Field("column"), "'; DROP TABLE")
|
||||||
|
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
||||||
|
if string(result) != expected {
|
||||||
|
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDateTruncValidation tests DateTrunc level validation
|
||||||
|
func TestDateTruncValidation(t *testing.T) {
|
||||||
|
field := Field("created_at")
|
||||||
|
|
||||||
|
validLevels := []string{
|
||||||
|
"microseconds", "milliseconds", "second", "minute", "hour",
|
||||||
|
"day", "week", "month", "quarter", "year", "decade", "century", "millennium",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, level := range validLevels {
|
||||||
|
t.Run("valid_"+level, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("DateTrunc(%q) should not panic for valid level", level)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
result := field.DateTrunc(level, "truncated")
|
||||||
|
if !strings.Contains(string(result), level) {
|
||||||
|
t.Errorf("DateTrunc result should contain level %q", level)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case-insensitive
|
||||||
|
t.Run("case_insensitive", func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("DateTrunc should accept uppercase level")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
result := field.DateTrunc("MONTH", "truncated")
|
||||||
|
if !strings.Contains(strings.ToLower(string(result)), "month") {
|
||||||
|
t.Errorf("DateTrunc should normalize case")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDateTruncInvalidLevel tests that DateTrunc panics on invalid level
|
||||||
|
func TestDateTruncInvalidLevel(t *testing.T) {
|
||||||
|
field := Field("created_at")
|
||||||
|
|
||||||
|
invalidLevels := []string{
|
||||||
|
"invalid",
|
||||||
|
"'; DROP TABLE users; --",
|
||||||
|
"day; DELETE FROM",
|
||||||
|
"",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, level := range invalidLevels {
|
||||||
|
t.Run("invalid_"+level, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("DateTrunc(%q, 'alias') should panic for invalid level", level)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.DateTrunc(level, "truncated")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDateTruncInvalidAlias tests that DateTrunc panics on invalid alias
|
||||||
|
func TestDateTruncInvalidAlias(t *testing.T) {
|
||||||
|
field := Field("created_at")
|
||||||
|
|
||||||
|
invalidAliases := []string{
|
||||||
|
"alias name",
|
||||||
|
"alias-name",
|
||||||
|
"'; DROP TABLE",
|
||||||
|
"123alias",
|
||||||
|
"alias.name",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, alias := range invalidAliases {
|
||||||
|
t.Run("invalid_alias_"+alias, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("DateTrunc('day', %q) should panic for invalid alias", alias)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.DateTrunc("day", alias)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTsRankValidation tests TsRank alias validation
|
||||||
|
func TestTsRankValidation(t *testing.T) {
|
||||||
|
field := Field("search_vector")
|
||||||
|
|
||||||
|
validAliases := []string{"rank", "score", "relevance_123", "_rank"}
|
||||||
|
for _, alias := range validAliases {
|
||||||
|
t.Run("valid_"+alias, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("TsRank should not panic for valid alias %q", alias)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
result := field.TsRank("query", alias)
|
||||||
|
if !strings.Contains(string(result), alias) {
|
||||||
|
t.Errorf("TsRank result should contain alias %q", alias)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
invalidAliases := []string{
|
||||||
|
"'; DROP TABLE",
|
||||||
|
"rank name",
|
||||||
|
"rank-name",
|
||||||
|
"123rank",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, alias := range invalidAliases {
|
||||||
|
t.Run("invalid_"+alias, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("TsRank should panic for invalid alias %q", alias)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.TsRank("query", alias)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRowNumberValidation tests ROW_NUMBER alias validation
|
||||||
|
func TestRowNumberValidation(t *testing.T) {
|
||||||
|
field := Field("id")
|
||||||
|
|
||||||
|
t.Run("valid_alias", func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("RowNumber should not panic for valid alias: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
result := field.RowNumber("row_num")
|
||||||
|
if !strings.Contains(string(result), "row_num") {
|
||||||
|
t.Errorf("RowNumber result should contain alias")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_alias", func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("RowNumber should panic for invalid alias")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.RowNumber("'; DROP TABLE")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSQLInjectionAttackVectors tests common SQL injection patterns
|
||||||
|
func TestSQLInjectionAttackVectors(t *testing.T) {
|
||||||
|
attacks := []string{
|
||||||
|
"'; DROP TABLE users; --",
|
||||||
|
"' OR '1'='1",
|
||||||
|
"'; DELETE FROM users WHERE '1'='1",
|
||||||
|
"1'; UPDATE users SET password='hacked'; --",
|
||||||
|
"admin'--",
|
||||||
|
"' UNION SELECT * FROM passwords--",
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("DateTrunc_alias_protection", func(t *testing.T) {
|
||||||
|
field := Field("created_at")
|
||||||
|
for _, attack := range attacks {
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("DateTrunc should prevent SQL injection: %q", attack)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.DateTrunc("day", attack)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TsRank_alias_protection", func(t *testing.T) {
|
||||||
|
field := Field("search_vector")
|
||||||
|
for _, attack := range attacks {
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("TsRank should prevent SQL injection: %q", attack)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
field.TsRank("query", attack)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ConcatWs_separator_escaping", func(t *testing.T) {
|
||||||
|
for _, attack := range attacks {
|
||||||
|
result := ConcatWs(attack, Field("col1"))
|
||||||
|
// Check that quotes are escaped (doubled)
|
||||||
|
if strings.Contains(attack, "'") && !strings.Contains(string(result), "''") {
|
||||||
|
t.Errorf("ConcatWs should escape quotes in attack: %q", attack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldName tests Field.Name() method
|
||||||
|
func TestFieldName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
field Field
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "field with table prefix",
|
||||||
|
field: Field("users.email"),
|
||||||
|
expected: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field with schema and table prefix",
|
||||||
|
field: Field("public.users.email"),
|
||||||
|
expected: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field without dot",
|
||||||
|
field: Field("email"),
|
||||||
|
expected: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field with multiple dots",
|
||||||
|
field: Field("schema.table.column"),
|
||||||
|
expected: "column",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "aggregate function",
|
||||||
|
field: Field("COUNT(users.id)"),
|
||||||
|
expected: "id)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.field.Name()
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Field.Name() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
86
pgm_table.go
Normal file
86
pgm_table.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package pgm
|
||||||
|
|
||||||
|
// Table in database
|
||||||
|
type Table struct {
|
||||||
|
textSearch *textSearchCTE
|
||||||
|
|
||||||
|
Name string
|
||||||
|
DerivedTable Query
|
||||||
|
PK []string
|
||||||
|
FieldCount uint16
|
||||||
|
debug bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// text search Common Table Expression
|
||||||
|
type textSearchCTE struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
alias string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug when set true will print generated query string in stdout
|
||||||
|
func (t *Table) Debug() Clause {
|
||||||
|
t.debug = true
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Table) Field(f string) Field {
|
||||||
|
return Field(t.Name + "." + f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert table statement
|
||||||
|
func (t *Table) Insert() InsertClause {
|
||||||
|
qb := &insertQry{
|
||||||
|
debug: t.debug,
|
||||||
|
table: t.Name,
|
||||||
|
fields: make([]string, 0, t.FieldCount),
|
||||||
|
vals: make([]string, 0, t.FieldCount),
|
||||||
|
args: make([]any, 0, t.FieldCount),
|
||||||
|
}
|
||||||
|
return qb
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Table) WithTextSearch(name, alias, textToSearch string) *Table {
|
||||||
|
t.textSearch = &textSearchCTE{name: name, value: textToSearch, alias: alias}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select table statement
|
||||||
|
func (t *Table) Select(field ...Field) SelectClause {
|
||||||
|
qb := &selectQry{
|
||||||
|
debug: t.debug,
|
||||||
|
fields: field,
|
||||||
|
textSearch: t.textSearch,
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.DerivedTable != nil {
|
||||||
|
tName, args := t.DerivedTable.Build(true)
|
||||||
|
qb.table = "(" + tName + ") AS " + t.Name
|
||||||
|
qb.args = args
|
||||||
|
} else {
|
||||||
|
qb.table = t.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
return qb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update table statement
|
||||||
|
func (t *Table) Update() UpdateClause {
|
||||||
|
qb := &updateQry{
|
||||||
|
debug: t.debug,
|
||||||
|
table: t.Name,
|
||||||
|
cols: make([]string, 0, t.FieldCount),
|
||||||
|
args: make([]any, 0, t.FieldCount),
|
||||||
|
}
|
||||||
|
return qb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete table statement
|
||||||
|
func (t *Table) Delete() DeleteCluase {
|
||||||
|
qb := &deleteQry{
|
||||||
|
debug: t.debug,
|
||||||
|
table: t.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
return qb
|
||||||
|
}
|
||||||
670
pgm_test.go
Normal file
670
pgm_test.go
Normal file
@@ -0,0 +1,670 @@
|
|||||||
|
// Patial Tech.
|
||||||
|
// Author, Ankit Patial
|
||||||
|
|
||||||
|
package pgm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGetPoolNotInitialized verifies that GetPool panics when pool is not initialized
|
||||||
|
func TestGetPoolNotInitialized(t *testing.T) {
|
||||||
|
// Save current pool state
|
||||||
|
currentPool := poolPGX.Load()
|
||||||
|
|
||||||
|
// Clear the pool to simulate uninitialized state
|
||||||
|
poolPGX.Store(nil)
|
||||||
|
|
||||||
|
// Restore pool after test
|
||||||
|
defer func() {
|
||||||
|
poolPGX.Store(currentPool)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Verify panic occurs
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Error("GetPool should panic when pool not initialized")
|
||||||
|
} else {
|
||||||
|
// Check panic message
|
||||||
|
msg, ok := r.(string)
|
||||||
|
if !ok || !strings.Contains(msg, "not initialized") {
|
||||||
|
t.Errorf("Expected panic message about initialization, got: %v", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
GetPool() // Should panic
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConfigValidation tests that InitPool validates configuration
|
||||||
|
func TestConfigValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantPanic bool
|
||||||
|
panicMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty connection string",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "",
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "connection string is empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MinConns greater than MaxConns",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "postgres://localhost/test",
|
||||||
|
MaxConns: 5,
|
||||||
|
MinConns: 10,
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "MinConns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative MaxConns",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "postgres://localhost/test",
|
||||||
|
MaxConns: -1,
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "negative",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative MinConns",
|
||||||
|
config: Config{
|
||||||
|
ConnString: "postgres://localhost/test",
|
||||||
|
MinConns: -1,
|
||||||
|
},
|
||||||
|
wantPanic: true,
|
||||||
|
panicMsg: "negative",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if tt.wantPanic && r == nil {
|
||||||
|
t.Errorf("InitPool should panic for %s", tt.name)
|
||||||
|
}
|
||||||
|
if !tt.wantPanic && r != nil {
|
||||||
|
t.Errorf("InitPool should not panic for %s, got: %v", tt.name, r)
|
||||||
|
}
|
||||||
|
if tt.wantPanic && r != nil {
|
||||||
|
msg := ""
|
||||||
|
switch v := r.(type) {
|
||||||
|
case string:
|
||||||
|
msg = v
|
||||||
|
case error:
|
||||||
|
msg = v.Error()
|
||||||
|
}
|
||||||
|
if !strings.Contains(msg, tt.panicMsg) {
|
||||||
|
t.Errorf("Expected panic message to contain %q, got: %q", tt.panicMsg, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
InitPool(tt.config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClosePoolIdempotent verifies ClosePool can be called multiple times safely
|
||||||
|
func TestClosePoolIdempotent(t *testing.T) {
|
||||||
|
// Save current pool
|
||||||
|
currentPool := poolPGX.Load()
|
||||||
|
defer func() {
|
||||||
|
poolPGX.Store(currentPool)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set to nil
|
||||||
|
poolPGX.Store(nil)
|
||||||
|
|
||||||
|
// Should not panic when called on nil pool
|
||||||
|
ClosePool()
|
||||||
|
ClosePool()
|
||||||
|
ClosePool()
|
||||||
|
|
||||||
|
// Should work fine - no panic expected
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsNotFound tests the error checking utility
|
||||||
|
func TestIsNotFound(t *testing.T) {
|
||||||
|
// Test with pgx.ErrNoRows
|
||||||
|
if !IsNotFound(pgx.ErrNoRows) {
|
||||||
|
t.Error("IsNotFound should return true for pgx.ErrNoRows")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with other errors
|
||||||
|
otherErr := errors.New("some other error")
|
||||||
|
if IsNotFound(otherErr) {
|
||||||
|
t.Error("IsNotFound should return false for non-ErrNoRows errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with nil
|
||||||
|
if IsNotFound(nil) {
|
||||||
|
t.Error("IsNotFound should return false for nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPgTime tests PostgreSQL timestamp conversion
|
||||||
|
func TestPgTime(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
pgTime := PgTime(now)
|
||||||
|
|
||||||
|
if !pgTime.Valid {
|
||||||
|
t.Error("PgTime should return valid timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pgTime.Time.Equal(now) {
|
||||||
|
t.Errorf("PgTime time mismatch: got %v, want %v", pgTime.Time, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPgTimeNow tests current time conversion
|
||||||
|
func TestPgTimeNow(t *testing.T) {
|
||||||
|
before := time.Now()
|
||||||
|
pgTime := PgTimeNow()
|
||||||
|
after := time.Now()
|
||||||
|
|
||||||
|
if !pgTime.Valid {
|
||||||
|
t.Error("PgTimeNow should return valid timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that time is between before and after
|
||||||
|
if pgTime.Time.Before(before) || pgTime.Time.After(after) {
|
||||||
|
t.Error("PgTimeNow should return current time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTsQueryFunctions tests full-text search query builders
|
||||||
|
func TestTsQueryFunctions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func(string) string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "TsAndQuery",
|
||||||
|
fn: TsAndQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello & world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsPrefixAndQuery",
|
||||||
|
fn: TsPrefixAndQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello:* & world:*",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsOrQuery",
|
||||||
|
fn: TsOrQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello | world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsPrefixOrQuery",
|
||||||
|
fn: TsPrefixOrQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello:* | world:*",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsAndQuery with multiple words",
|
||||||
|
fn: TsAndQuery,
|
||||||
|
input: "one two three",
|
||||||
|
expected: "one & two & three",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsOrQuery with multiple words",
|
||||||
|
fn: TsOrQuery,
|
||||||
|
input: "one two three",
|
||||||
|
expected: "one | two | three",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TsAndQuery with extra spaces",
|
||||||
|
fn: TsAndQuery,
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello & world",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.fn(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("%s(%q) = %q, want %q", tt.name, tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldsWithSuffix tests the internal suffix function
|
||||||
|
func TestFieldsWithSuffix(t *testing.T) {
|
||||||
|
result := fieldsWithSufix("hello world", ":*")
|
||||||
|
expected := []string{"hello:*", "world:*"}
|
||||||
|
|
||||||
|
if len(result) != len(expected) {
|
||||||
|
t.Fatalf("fieldsWithSufix length mismatch: got %d, want %d", len(result), len(expected))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, v := range result {
|
||||||
|
if v != expected[i] {
|
||||||
|
t.Errorf("fieldsWithSufix[%d] = %q, want %q", i, v, expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBeginTxErrorWrapping tests that BeginTx preserves error context
|
||||||
|
func TestBeginTxErrorWrapping(t *testing.T) {
|
||||||
|
// This test verifies basic behavior without requiring a database connection
|
||||||
|
// Actual error wrapping would require a failing database connection
|
||||||
|
|
||||||
|
// Skip if no pool initialized (unit test environment)
|
||||||
|
if poolPGX.Load() == nil {
|
||||||
|
t.Skip("Skipping BeginTx test - requires initialized pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Just verify the function accepts a context
|
||||||
|
// In a real scenario with a cancelled context, it would return an error
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// We can't fully test without a real DB connection, so just document behavior
|
||||||
|
_, err := BeginTx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Error is expected in test environment without valid DB
|
||||||
|
t.Logf("BeginTx returned error (expected in test env): %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestContextCancellation tests that cancelled context is handled
|
||||||
|
func TestContextCancellation(t *testing.T) {
|
||||||
|
// Create an already-cancelled context
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
// Note: This test would require an actual database connection to verify
|
||||||
|
// that the context cancellation is properly propagated. We're testing
|
||||||
|
// that the context is accepted by the function signature.
|
||||||
|
|
||||||
|
// The actual behavior would be tested in integration tests
|
||||||
|
// For now, we just verify the function doesn't panic with cancelled context
|
||||||
|
|
||||||
|
// Skip if no pool initialized (unit test environment)
|
||||||
|
if poolPGX.Load() == nil {
|
||||||
|
t.Skip("Skipping context cancellation test - requires initialized pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// This would return an error about cancelled context in real scenario
|
||||||
|
_, err := BeginTx(ctx)
|
||||||
|
if err == nil {
|
||||||
|
t.Log("BeginTx with cancelled context - error expected in real scenario")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestErrorTypes verifies exported error variables
|
||||||
|
func TestErrorTypes(t *testing.T) {
|
||||||
|
if ErrConnStringMissing == nil {
|
||||||
|
t.Error("ErrConnStringMissing should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ErrInitTX == nil {
|
||||||
|
t.Error("ErrInitTX should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ErrCommitTX == nil {
|
||||||
|
t.Error("ErrCommitTX should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ErrNoRows == nil {
|
||||||
|
t.Error("ErrNoRows should be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check error messages are descriptive
|
||||||
|
if !strings.Contains(ErrConnStringMissing.Error(), "connection string") {
|
||||||
|
t.Error("ErrConnStringMissing should mention connection string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStringPoolGetPut tests the string builder pool
|
||||||
|
func TestStringPoolGetPut(t *testing.T) {
|
||||||
|
// Get a string builder from pool
|
||||||
|
sb1 := getSB()
|
||||||
|
if sb1 == nil {
|
||||||
|
t.Fatal("getSB should return non-nil string builder")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use it
|
||||||
|
sb1.WriteString("test")
|
||||||
|
if sb1.String() != "test" {
|
||||||
|
t.Error("String builder should work normally")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put it back
|
||||||
|
putSB(sb1)
|
||||||
|
|
||||||
|
// Get another one - should be reset
|
||||||
|
sb2 := getSB()
|
||||||
|
if sb2 == nil {
|
||||||
|
t.Fatal("getSB should return non-nil string builder")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be empty after reset
|
||||||
|
if sb2.Len() != 0 {
|
||||||
|
t.Error("String builder from pool should be reset")
|
||||||
|
}
|
||||||
|
|
||||||
|
putSB(sb2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConditionActionTypes tests condition action enum
|
||||||
|
func TestConditionActionTypes(t *testing.T) {
|
||||||
|
// Verify enum values are distinct
|
||||||
|
actions := []CondAction{
|
||||||
|
CondActionNothing,
|
||||||
|
CondActionNeedToClose,
|
||||||
|
CondActionSubQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[CondAction]bool)
|
||||||
|
for _, action := range actions {
|
||||||
|
if seen[action] {
|
||||||
|
t.Errorf("Duplicate CondAction value: %v", action)
|
||||||
|
}
|
||||||
|
seen[action] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSelectQueryBuilderBasics tests basic query building
|
||||||
|
func TestSelectQueryBuilderBasics(t *testing.T) {
|
||||||
|
// Create a test table
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
emailField := Field("users.email")
|
||||||
|
|
||||||
|
// Test basic select
|
||||||
|
qry := testTable.Select(idField, emailField)
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "SELECT") {
|
||||||
|
t.Error("Query should contain SELECT")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "users.id") {
|
||||||
|
t.Error("Query should contain users.id field")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "users.email") {
|
||||||
|
t.Error("Query should contain users.email field")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "FROM users") {
|
||||||
|
t.Error("Query should contain FROM users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWhereConditionAccumulation tests that WHERE conditions accumulate
|
||||||
|
func TestWhereConditionAccumulation(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
statusField := Field("users.status")
|
||||||
|
|
||||||
|
// Create a query and add multiple WHERE conditions
|
||||||
|
qry := testTable.Select(idField).
|
||||||
|
Where(idField.Eq(1)).
|
||||||
|
Where(statusField.Eq("active"))
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
// Both conditions should be present
|
||||||
|
if !strings.Contains(sql, "WHERE") {
|
||||||
|
t.Error("Query should contain WHERE clause")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have multiple conditions (this demonstrates accumulation)
|
||||||
|
if strings.Count(sql, "$") < 2 {
|
||||||
|
t.Error("Query should have multiple parameterized values")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInsertQueryBuilder tests insert query building
|
||||||
|
func TestInsertQueryBuilder(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
|
||||||
|
qry := testTable.Insert()
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "INSERT INTO users") {
|
||||||
|
t.Error("Query should contain INSERT INTO users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpdateQueryBuilder tests update query building
|
||||||
|
func TestUpdateQueryBuilder(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
|
||||||
|
qry := testTable.Update().
|
||||||
|
Where(idField.Eq(1))
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "UPDATE users") {
|
||||||
|
t.Error("Query should contain UPDATE users")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "WHERE") {
|
||||||
|
t.Error("Query should contain WHERE clause")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteQueryBuilder tests delete query building
|
||||||
|
func TestDeleteQueryBuilder(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
|
||||||
|
qry := testTable.Delete().Where(idField.Eq(1))
|
||||||
|
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !strings.Contains(sql, "DELETE FROM users") {
|
||||||
|
t.Error("Query should contain DELETE FROM users")
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "WHERE") {
|
||||||
|
t.Error("Query should contain WHERE clause")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldConditioners tests various field condition methods
|
||||||
|
func TestFieldConditioners(t *testing.T) {
|
||||||
|
field := Field("users.age")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
condition Conditioner
|
||||||
|
checkFunc func(string) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Eq",
|
||||||
|
condition: field.Eq(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " = $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NotEq",
|
||||||
|
condition: field.NotEq(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " != $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gt",
|
||||||
|
condition: field.Gt(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " > $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Lt",
|
||||||
|
condition: field.Lt(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " < $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gte",
|
||||||
|
condition: field.Gte(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " >= $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Lte",
|
||||||
|
condition: field.Lte(25),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " <= $") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IsNull",
|
||||||
|
condition: field.IsNull(),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " IS NULL") },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IsNotNull",
|
||||||
|
condition: field.IsNotNull(),
|
||||||
|
checkFunc: func(s string) bool { return strings.Contains(s, " IS NOT NULL") },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Build a simple query to see the condition
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
qry := testTable.Select(field).Where(tt.condition)
|
||||||
|
sql := qry.String()
|
||||||
|
|
||||||
|
if !tt.checkFunc(sql) {
|
||||||
|
t.Errorf("%s condition not properly rendered in SQL: %s", tt.name, sql)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFieldFunctions tests SQL function wrappers
|
||||||
|
func TestFieldFunctions(t *testing.T) {
|
||||||
|
field := Field("users.count")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result Field
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Count",
|
||||||
|
result: field.Count(),
|
||||||
|
contains: "COUNT(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Sum",
|
||||||
|
result: field.Sum(),
|
||||||
|
contains: "SUM(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Avg",
|
||||||
|
result: field.Avg(),
|
||||||
|
contains: "AVG(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max",
|
||||||
|
result: field.Max(),
|
||||||
|
contains: "MAX(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Min",
|
||||||
|
result: field.Min(),
|
||||||
|
contains: "Min(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Lower",
|
||||||
|
result: field.Lower(),
|
||||||
|
contains: "LOWER(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Upper",
|
||||||
|
result: field.Upper(),
|
||||||
|
contains: "UPPER(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Trim",
|
||||||
|
result: field.Trim(),
|
||||||
|
contains: "TRIM(",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Asc",
|
||||||
|
result: field.Asc(),
|
||||||
|
contains: " ASC",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Desc",
|
||||||
|
result: field.Desc(),
|
||||||
|
contains: " DESC",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resultStr := string(tt.result)
|
||||||
|
if !strings.Contains(resultStr, tt.contains) {
|
||||||
|
t.Errorf("%s should contain %q, got: %s", tt.name, tt.contains, resultStr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInsertQueryValidation tests that Insert queries validate fields are set
|
||||||
|
func TestInsertQueryValidation(t *testing.T) {
|
||||||
|
testTable := &Table{Name: "users", FieldCount: 10}
|
||||||
|
idField := Field("users.id")
|
||||||
|
|
||||||
|
t.Run("Exec without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert()
|
||||||
|
err := qry.Exec(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.Exec() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExecTx without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert()
|
||||||
|
// We can't test ExecTx without a real transaction, but we can test the error is defined
|
||||||
|
err := qry.ExecTx(context.Background(), nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.ExecTx() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("First without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert().Returning(idField)
|
||||||
|
var id int
|
||||||
|
err := qry.First(context.Background(), &id)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.First() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("FirstTx without fields", func(t *testing.T) {
|
||||||
|
qry := testTable.Insert().Returning(idField)
|
||||||
|
var id int
|
||||||
|
err := qry.FirstTx(context.Background(), nil, &id)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Insert.FirstTx() should return error when no fields set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no fields to insert") {
|
||||||
|
t.Errorf("Expected error about no fields, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package branchuser
|
package branchuser
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table branch_users
|
||||||
|
All pgm.Field = "branch_users.*"
|
||||||
// BranchID field has db type "bigint NOT NULL"
|
// BranchID field has db type "bigint NOT NULL"
|
||||||
BranchID pgm.Field = "branch_users.branch_id"
|
BranchID pgm.Field = "branch_users.branch_id"
|
||||||
// UserID field has db type "bigint NOT NULL"
|
// UserID field has db type "bigint NOT NULL"
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package comment
|
package comment
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table comments
|
||||||
|
All pgm.Field = "comments.*"
|
||||||
// ID field has db type "integer NOT NULL"
|
// ID field has db type "integer NOT NULL"
|
||||||
ID pgm.Field = "comments.id"
|
ID pgm.Field = "comments.id"
|
||||||
// PostID field has db type "integer NOT NULL"
|
// PostID field has db type "integer NOT NULL"
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package employee
|
package employee
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table employees
|
||||||
|
All pgm.Field = "employees.*"
|
||||||
// ID field has db type "integer NOT NULL"
|
// ID field has db type "integer NOT NULL"
|
||||||
ID pgm.Field = "employees.id"
|
ID pgm.Field = "employees.id"
|
||||||
// Name field has db type "var NOT NULL"
|
// Name field has db type "var NOT NULL"
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package post
|
package post
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table posts
|
||||||
|
All pgm.Field = "posts.*"
|
||||||
// ID field has db type "integer NOT NULL"
|
// ID field has db type "integer NOT NULL"
|
||||||
ID pgm.Field = "posts.id"
|
ID pgm.Field = "posts.id"
|
||||||
// UserID field has db type "integer NOT NULL"
|
// UserID field has db type "integer NOT NULL"
|
||||||
|
|||||||
@@ -1,15 +1,22 @@
|
|||||||
// Code generated by code.patial.tech/go/pgm/cmd
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
// DO NOT EDIT.
|
|
||||||
|
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
User = pgm.Table{Name: "users", FieldCount: 11}
|
User = pgm.Table{Name: "users", FieldCount: 12}
|
||||||
UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8}
|
UserSession = pgm.Table{Name: "user_sessions", FieldCount: 8}
|
||||||
BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5}
|
BranchUser = pgm.Table{Name: "branch_users", FieldCount: 5}
|
||||||
Post = pgm.Table{Name: "posts", FieldCount: 5}
|
Post = pgm.Table{Name: "posts", FieldCount: 5}
|
||||||
Comment = pgm.Table{Name: "comments", FieldCount: 5}
|
Comment = pgm.Table{Name: "comments", FieldCount: 5}
|
||||||
Employee = pgm.Table{Name: "employees", FieldCount: 5}
|
Employee = pgm.Table{Name: "employees", FieldCount: 5}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func DerivedTable(tblName string, fromQry pgm.Query) pgm.Table {
|
||||||
|
t := pgm.Table{
|
||||||
|
Name: tblName,
|
||||||
|
DerivedTable: fromQry,
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package user
|
package user
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table users
|
||||||
|
All pgm.Field = "users.*"
|
||||||
// ID field has db type "integer NOT NULL"
|
// ID field has db type "integer NOT NULL"
|
||||||
ID pgm.Field = "users.id"
|
ID pgm.Field = "users.id"
|
||||||
// Name field has db type "character varying(255) NOT NULL"
|
// Name field has db type "character varying(255) NOT NULL"
|
||||||
@@ -23,6 +25,8 @@ const (
|
|||||||
StatusID pgm.Field = "users.status_id"
|
StatusID pgm.Field = "users.status_id"
|
||||||
// MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying"
|
// MfaKind field has db type "character varying(50) DEFAULT 'None'::character varying"
|
||||||
MfaKind pgm.Field = "users.mfa_kind"
|
MfaKind pgm.Field = "users.mfa_kind"
|
||||||
|
// SearchVector field has db type "tsvector"
|
||||||
|
SearchVector pgm.Field = "users.search_vector"
|
||||||
// CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
// CreatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||||
CreatedAt pgm.Field = "users.created_at"
|
CreatedAt pgm.Field = "users.created_at"
|
||||||
// UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
// UpdatedAt field has db type "timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP"
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
// Code generated by db-gen. DO NOT EDIT.
|
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||||
|
|
||||||
package usersession
|
package usersession
|
||||||
|
|
||||||
import "code.patial.tech/go/pgm"
|
import "code.patial.tech/go/pgm"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// All fields in table user_sessions
|
||||||
|
All pgm.Field = "user_sessions.*"
|
||||||
// ID field has db type "character varying NOT NULL"
|
// ID field has db type "character varying NOT NULL"
|
||||||
ID pgm.Field = "user_sessions.id"
|
ID pgm.Field = "user_sessions.id"
|
||||||
// CreatedAt field has db type "timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL"
|
// CreatedAt field has db type "timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL"
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestDelete(t *testing.T) {
|
func TestDelete(t *testing.T) {
|
||||||
expected := "DELETE FROM users WHERE users.id = $1 AND users.status_id NOT IN($2)"
|
expected := "DELETE FROM users WHERE users.id = $1 AND users.status_id != ANY($2)"
|
||||||
got := db.User.Delete().
|
got := db.User.Delete().
|
||||||
Where(user.ID.Eq(1), user.StatusID.NotIn(1, 2, 3)).
|
Where(user.ID.Eq(1), user.StatusID.NotAny(1, 2, 3)).
|
||||||
String()
|
String()
|
||||||
if got != expected {
|
if got != expected {
|
||||||
t.Errorf("got %q, want %q", got, expected)
|
t.Errorf("got %q, want %q", got, expected)
|
||||||
|
|||||||
@@ -17,24 +17,7 @@ func TestInsertQuery(t *testing.T) {
|
|||||||
Returning(user.ID).
|
Returning(user.ID).
|
||||||
String()
|
String()
|
||||||
|
|
||||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id"
|
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING users.id"
|
||||||
if got != expected {
|
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInsertSetMap(t *testing.T) {
|
|
||||||
got := db.User.Insert().
|
|
||||||
SetMap(map[pgm.Field]any{
|
|
||||||
user.Email: "aa@aa.com",
|
|
||||||
user.Phone: 8889991234,
|
|
||||||
user.FirstName: "fname",
|
|
||||||
user.LastName: "lname",
|
|
||||||
}).
|
|
||||||
Returning(user.ID).
|
|
||||||
String()
|
|
||||||
|
|
||||||
expected := "INSERT INTO users(email, phone, first_name, last_name) VALUES($1, $2, $3, $4) RETURNING id"
|
|
||||||
if got != expected {
|
if got != expected {
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||||
}
|
}
|
||||||
@@ -53,7 +36,7 @@ func TestInsertQuery2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkInsertQuery-12 1952412 605.3 ns/op 1144 B/op 18 allocs/op
|
// BenchmarkInsertQuery-12 2517757 459.6 ns/op 1032 B/op 13 allocs/op
|
||||||
func BenchmarkInsertQuery(b *testing.B) {
|
func BenchmarkInsertQuery(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = db.User.Insert().
|
_ = db.User.Insert().
|
||||||
@@ -67,7 +50,7 @@ func BenchmarkInsertQuery(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkInsertSetMap-12 1534039 777.1 ns/op 1480 B/op 20 allocs/op
|
// BenchmarkInsertSetMap-12 1812555 644.1 ns/op 1368 B/op 15 allocs/op
|
||||||
func BenchmarkInsertSetMap(b *testing.B) {
|
func BenchmarkInsertSetMap(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = db.User.Insert().
|
_ = db.User.Insert().
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func TestQryBuilder2(t *testing.T) {
|
|||||||
),
|
),
|
||||||
).
|
).
|
||||||
Where(
|
Where(
|
||||||
user.LastName.NEq(7),
|
user.LastName.NotEq(7),
|
||||||
user.Phone.Like("%123%"),
|
user.Phone.Like("%123%"),
|
||||||
user.UpdatedAt.IsNotNull(),
|
user.UpdatedAt.IsNotNull(),
|
||||||
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
||||||
@@ -40,7 +40,7 @@ func TestQryBuilder2(t *testing.T) {
|
|||||||
expected := "SELECT users.email, users.first_name FROM users JOIN user_sessions ON users.id = user_sessions.user_id" +
|
expected := "SELECT users.email, users.first_name FROM users JOIN user_sessions ON users.id = user_sessions.user_id" +
|
||||||
" JOIN branch_users ON users.id = branch_users.user_id WHERE users.id = $1 AND (users.status_id = $2 OR users.updated_at = $3)" +
|
" JOIN branch_users ON users.id = branch_users.user_id WHERE users.id = $1 AND (users.status_id = $2 OR users.updated_at = $3)" +
|
||||||
" AND users.mfa_kind = $4 AND (users.first_name = $5 OR users.middle_name = $6) AND users.last_name != $7 AND users.phone" +
|
" AND users.mfa_kind = $4 AND (users.first_name = $5 OR users.middle_name = $6) AND users.last_name != $7 AND users.phone" +
|
||||||
" LIKE $8 AND users.updated_at IS NOT NULL AND users.email NOT IN(SELECT users.id FROM users WHERE users.id = $9) LIMIT 10 OFFSET 100"
|
" LIKE $8 AND users.updated_at IS NOT NULL AND users.email != ANY(SELECT users.id FROM users WHERE users.id = $9) LIMIT 10 OFFSET 100"
|
||||||
if expected != got {
|
if expected != got {
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||||
}
|
}
|
||||||
@@ -60,7 +60,75 @@ func TestSelectWithHaving(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkSelect-12 668817 1753 ns/op 4442 B/op 59 allocs/op
|
func TestSelectWithJoin(t *testing.T) {
|
||||||
|
got := db.User.Select(user.Email, user.FirstName).
|
||||||
|
Join(db.UserSession, user.ID, usersession.UserID).
|
||||||
|
LeftJoin(db.BranchUser, user.ID, branchuser.UserID, pgm.Or(branchuser.RoleID.Eq("1"), branchuser.RoleID.Eq("2"))).
|
||||||
|
Where(
|
||||||
|
user.ID.Eq(3),
|
||||||
|
pgm.Or(
|
||||||
|
user.StatusID.Eq(4),
|
||||||
|
user.UpdatedAt.Eq(5),
|
||||||
|
),
|
||||||
|
).
|
||||||
|
Limit(10).
|
||||||
|
Offset(100).
|
||||||
|
String()
|
||||||
|
|
||||||
|
expected := "SELECT users.email, users.first_name " +
|
||||||
|
"FROM users JOIN user_sessions ON users.id = user_sessions.user_id " +
|
||||||
|
"LEFT JOIN branch_users ON users.id = branch_users.user_id AND (branch_users.role_id = $1 OR branch_users.role_id = $2) " +
|
||||||
|
"WHERE users.id = $3 AND (users.status_id = $4 OR users.updated_at = $5) " +
|
||||||
|
"LIMIT 10 OFFSET 100"
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectDerived(t *testing.T) {
|
||||||
|
expected := "SELECT t.* FROM (SELECT users.*, ROW_NUMBER() OVER (PARTITION BY users.status_id ORDER BY users.created_at DESC) AS rn" +
|
||||||
|
" FROM users WHERE users.status_id = $1) AS t WHERE t.rn <= $2" +
|
||||||
|
" ORDER BY t.status_id, t.created_at DESC"
|
||||||
|
|
||||||
|
qry := db.User.
|
||||||
|
Select(user.All, user.CreatedAt.RowNumberDescPartionBy(user.StatusID, "rn")).
|
||||||
|
Where(user.StatusID.Eq(1))
|
||||||
|
|
||||||
|
tbl := db.DerivedTable("t", qry)
|
||||||
|
got := tbl.
|
||||||
|
Select(tbl.Field("*")).
|
||||||
|
Where(tbl.Field("rn").Lte(5)).
|
||||||
|
OrderBy(tbl.Field("status_id"), tbl.Field("created_at").Desc()).
|
||||||
|
String()
|
||||||
|
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectTV(t *testing.T) {
|
||||||
|
expected := "WITH ts AS (SELECT to_tsquery('english', $1) AS query)" +
|
||||||
|
" SELECT users.first_name, users.last_name, users.email, TS_RANK(users.search_vector, ts.query) AS rank" +
|
||||||
|
" FROM users" +
|
||||||
|
" JOIN user_sessions ON users.id = user_sessions.user_id" +
|
||||||
|
" CROSS JOIN ts" +
|
||||||
|
" WHERE users.status_id = $2 AND users.search_vector @@ ts.query" +
|
||||||
|
" ORDER BY rank DESC"
|
||||||
|
|
||||||
|
qry := db.User.
|
||||||
|
WithTextSearch("ts", "query", "text to search").
|
||||||
|
Select(user.FirstName, user.LastName, user.Email, user.SearchVector.TsRank("ts.query", "rank")).
|
||||||
|
Join(db.UserSession, user.ID, usersession.UserID).
|
||||||
|
Where(user.StatusID.Eq(1), user.SearchVector.TsQuery("ts.query")).
|
||||||
|
OrderBy(pgm.Field("rank").Desc())
|
||||||
|
|
||||||
|
got := qry.String()
|
||||||
|
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("\nexpected: %q\n\ngot: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
|
// BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op
|
||||||
func BenchmarkSelect(b *testing.B) {
|
func BenchmarkSelect(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
@@ -80,7 +148,7 @@ func BenchmarkSelect(b *testing.B) {
|
|||||||
),
|
),
|
||||||
).
|
).
|
||||||
Where(
|
Where(
|
||||||
user.LastName.NEq(7),
|
user.LastName.NotEq(7),
|
||||||
user.Phone.Like("%123%"),
|
user.Phone.Like("%123%"),
|
||||||
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
user.Email.NotInSubQuery(db.User.Select(user.ID).Where(user.ID.Eq(123))),
|
||||||
).
|
).
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package playground
|
package playground
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"code.patial.tech/go/pgm"
|
|
||||||
"code.patial.tech/go/pgm/playground/db"
|
"code.patial.tech/go/pgm/playground/db"
|
||||||
"code.patial.tech/go/pgm/playground/db/user"
|
"code.patial.tech/go/pgm/playground/db/user"
|
||||||
)
|
)
|
||||||
@@ -17,7 +18,7 @@ func TestUpdateQuery(t *testing.T) {
|
|||||||
user.Email.Eq("aa@aa.com"),
|
user.Email.Eq("aa@aa.com"),
|
||||||
).
|
).
|
||||||
Where(
|
Where(
|
||||||
user.StatusID.NEq(1),
|
user.StatusID.NotEq(1),
|
||||||
).
|
).
|
||||||
String()
|
String()
|
||||||
|
|
||||||
@@ -27,28 +28,22 @@ func TestUpdateQuery(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateSetMap(t *testing.T) {
|
func TestUpdateQueryValidation(t *testing.T) {
|
||||||
got := db.User.Update().
|
// Test that UPDATE without Set() returns error
|
||||||
SetMap(map[pgm.Field]any{
|
err := db.User.Update().
|
||||||
user.FirstName: "ankit",
|
Where(user.Email.Eq("aa@aa.com")).
|
||||||
user.MiddleName: "singh",
|
Exec(context.Background())
|
||||||
user.LastName: "patial",
|
|
||||||
}).
|
|
||||||
Where(
|
|
||||||
user.Email.Eq("aa@aa.com"),
|
|
||||||
).
|
|
||||||
Where(
|
|
||||||
user.StatusID.NEq(1),
|
|
||||||
).
|
|
||||||
String()
|
|
||||||
|
|
||||||
expected := "UPDATE users SET first_name=$1, middle_name=$2, last_name=$3 WHERE users.email = $4 AND users.status_id != $5"
|
if err == nil {
|
||||||
if got != expected {
|
t.Error("Expected error when calling Exec() without Set(), got nil")
|
||||||
t.Errorf("\nexpected: %q\ngot: %q", expected, got)
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "no columns to update") {
|
||||||
|
t.Errorf("Expected error message to contain 'no columns to update', got: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkUpdateQuery-12 2004985 592.2 ns/op 1176 B/op 20 allocs/op
|
// BenchmarkUpdateQuery-12 2334889 503.6 ns/op 1112 B/op 17 allocs/op
|
||||||
func BenchmarkUpdateQuery(b *testing.B) {
|
func BenchmarkUpdateQuery(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = db.User.Update().
|
_ = db.User.Update().
|
||||||
@@ -59,7 +54,7 @@ func BenchmarkUpdateQuery(b *testing.B) {
|
|||||||
user.Email.Eq("aa@aa.com"),
|
user.Email.Eq("aa@aa.com"),
|
||||||
).
|
).
|
||||||
Where(
|
Where(
|
||||||
user.StatusID.NEq(1),
|
user.StatusID.NotEq(1),
|
||||||
).
|
).
|
||||||
String()
|
String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ CREATE TABLE public.users (
|
|||||||
last_name character varying(50) NOT NULL,
|
last_name character varying(50) NOT NULL,
|
||||||
status_id smallint,
|
status_id smallint,
|
||||||
mfa_kind character varying(50) DEFAULT 'None'::character varying,
|
mfa_kind character varying(50) DEFAULT 'None'::character varying,
|
||||||
|
search_vector tsvector,
|
||||||
created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
created_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP
|
updated_at timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
);
|
);
|
||||||
|
|||||||
110
pool.go
110
pool.go
@@ -1,110 +0,0 @@
|
|||||||
// Patial Tech.
|
|
||||||
// Author, Ankit Patial
|
|
||||||
|
|
||||||
package pgm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"log/slog"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
poolPGX atomic.Pointer[pgxpool.Pool]
|
|
||||||
poolStringBuilder = sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return new(strings.Builder)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ErrConnStringMissing = errors.New("connection string is empty")
|
|
||||||
ErrInitTX = errors.New("failed to init db.tx")
|
|
||||||
ErrCommitTX = errors.New("failed to commit db.tx")
|
|
||||||
ErrNoRows = errors.New("no data found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
ConnString string
|
|
||||||
MaxConns int32
|
|
||||||
MinConns int32
|
|
||||||
MaxConnLifetime time.Duration
|
|
||||||
MaxConnIdleTime time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitPool will create new pgxpool.Pool and will keep it for its working
|
|
||||||
func InitPool(conf Config) {
|
|
||||||
if conf.ConnString == "" {
|
|
||||||
panic(ErrConnStringMissing)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := pgxpool.ParseConfig(conf.ConnString)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.MaxConns > 0 {
|
|
||||||
cfg.MaxConns = conf.MaxConns // 100
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.MinConns > 0 {
|
|
||||||
cfg.MinConns = conf.MaxConns // 5
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.MaxConnLifetime > 0 {
|
|
||||||
cfg.MaxConnLifetime = conf.MaxConnLifetime // time.Minute * 10
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.MaxConnIdleTime > 0 {
|
|
||||||
cfg.MaxConnIdleTime = conf.MaxConnIdleTime // time.Minute * 5
|
|
||||||
}
|
|
||||||
|
|
||||||
p, err := pgxpool.NewWithConfig(context.Background(), cfg)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = p.Ping(context.Background()); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
poolPGX.Store(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// get string builder from pool
|
|
||||||
func getSB() *strings.Builder {
|
|
||||||
return poolStringBuilder.Get().(*strings.Builder)
|
|
||||||
}
|
|
||||||
|
|
||||||
// put string builder back to pool
|
|
||||||
func putSB(sb *strings.Builder) {
|
|
||||||
sb.Reset()
|
|
||||||
poolStringBuilder.Put(sb)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPool instance
|
|
||||||
func GetPool() *pgxpool.Pool {
|
|
||||||
return poolPGX.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BeginTx begins a pgx poll transaction
|
|
||||||
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
|
||||||
tx, err := poolPGX.Load().Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error(err.Error())
|
|
||||||
return nil, errors.New("failed to open db tx")
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNotFound error check
|
|
||||||
func IsNotFound(err error) bool {
|
|
||||||
return errors.Is(err, pgx.ErrNoRows)
|
|
||||||
}
|
|
||||||
188
qry.go
188
qry.go
@@ -7,156 +7,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Clause interface {
|
Clause interface {
|
||||||
|
Insert() InsertClause
|
||||||
Select(fields ...Field) SelectClause
|
Select(fields ...Field) SelectClause
|
||||||
// Insert() InsertSet
|
Update() UpdateClause
|
||||||
// Update() UpdateSet
|
Delete() DeleteCluase
|
||||||
// Delete() WhereOrExec
|
|
||||||
}
|
|
||||||
|
|
||||||
SelectClause interface {
|
|
||||||
// Join and Inner Join are same
|
|
||||||
Join(m Table, t1Field, t2Field Field) SelectClause
|
|
||||||
LeftJoin(m Table, t1Field, t2Field Field) SelectClause
|
|
||||||
RightJoin(m Table, t1Field, t2Field Field) SelectClause
|
|
||||||
FullJoin(m Table, t1Field, t2Field Field) SelectClause
|
|
||||||
CrossJoin(m Table) SelectClause
|
|
||||||
WhereClause
|
|
||||||
OrderByClause
|
|
||||||
GroupByClause
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
raw(prefixArgs []any) (string, []any)
|
|
||||||
}
|
|
||||||
|
|
||||||
WhereClause interface {
|
|
||||||
Where(cond ...Conditioner) AfterWhere
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterWhere interface {
|
|
||||||
WhereClause
|
|
||||||
GroupByClause
|
|
||||||
OrderByClause
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
GroupByClause interface {
|
|
||||||
GroupBy(fields ...Field) AfterGroupBy
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterGroupBy interface {
|
|
||||||
HavinClause
|
|
||||||
OrderByClause
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
HavinClause interface {
|
|
||||||
Having(cond ...Conditioner) AfterHaving
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterHaving interface {
|
|
||||||
OrderByClause
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
OrderByClause interface {
|
|
||||||
OrderBy(fields ...Field) AfterOrderBy
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterOrderBy interface {
|
|
||||||
LimitClause
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
LimitClause interface {
|
|
||||||
Limit(v int) AfterLimit
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterLimit interface {
|
|
||||||
OffsetClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
OffsetClause interface {
|
|
||||||
Offset(v int) AfterOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterOffset interface {
|
|
||||||
LimitClause
|
|
||||||
Query
|
|
||||||
}
|
|
||||||
|
|
||||||
Conditioner interface {
|
|
||||||
Condition(args *[]any, idx int) string
|
|
||||||
}
|
|
||||||
|
|
||||||
Insert interface {
|
|
||||||
Set(field Field, val any) InsertClause
|
|
||||||
SetMap(fields map[Field]any) InsertClause
|
|
||||||
}
|
|
||||||
|
|
||||||
InsertClause interface {
|
|
||||||
Insert
|
|
||||||
Returning(field Field) First
|
|
||||||
OnConflict(fields ...Field) Do
|
|
||||||
Execute
|
|
||||||
Stringer
|
|
||||||
}
|
|
||||||
|
|
||||||
Do interface {
|
|
||||||
DoNothing() Execute
|
|
||||||
DoUpdate(fields ...Field) Execute
|
|
||||||
}
|
|
||||||
|
|
||||||
Update interface {
|
|
||||||
Set(field Field, val any) UpdateClause
|
|
||||||
SetMap(fields map[Field]any) UpdateClause
|
|
||||||
}
|
|
||||||
|
|
||||||
UpdateClause interface {
|
|
||||||
Update
|
|
||||||
Where(cond ...Conditioner) WhereOrExec
|
|
||||||
}
|
|
||||||
|
|
||||||
WhereOrExec interface {
|
|
||||||
Where(cond ...Conditioner) WhereOrExec
|
|
||||||
Execute
|
|
||||||
}
|
|
||||||
|
|
||||||
Query interface {
|
|
||||||
First
|
|
||||||
All
|
|
||||||
Stringer
|
|
||||||
}
|
|
||||||
|
|
||||||
First interface {
|
|
||||||
First(ctx context.Context, dest ...any) error
|
|
||||||
FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error
|
|
||||||
Stringer
|
|
||||||
}
|
|
||||||
|
|
||||||
All interface {
|
|
||||||
// Query rows
|
|
||||||
//
|
|
||||||
// don't forget to close() rows
|
|
||||||
All(ctx context.Context, rows RowsCb) error
|
|
||||||
// Query rows
|
|
||||||
//
|
|
||||||
// don't forget to close() rows
|
|
||||||
AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute interface {
|
Execute interface {
|
||||||
@@ -169,26 +30,26 @@ type (
|
|||||||
String() string
|
String() string
|
||||||
}
|
}
|
||||||
|
|
||||||
RowScanner interface {
|
Conditioner interface {
|
||||||
Scan(dest ...any) error
|
Condition(args *[]any, idx int) string
|
||||||
}
|
}
|
||||||
|
|
||||||
RowsCb func(row RowScanner) error
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func joinFileds(fields []Field) string {
|
var sbPool = sync.Pool{
|
||||||
sb := getSB()
|
New: func() any {
|
||||||
defer putSB(sb)
|
return new(strings.Builder)
|
||||||
for i, f := range fields {
|
},
|
||||||
if i == 0 {
|
}
|
||||||
sb.WriteString(f.String())
|
|
||||||
} else {
|
|
||||||
sb.WriteString(", ")
|
|
||||||
sb.WriteString(f.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.String()
|
// get string builder from pool
|
||||||
|
func getSB() *strings.Builder {
|
||||||
|
return sbPool.Get().(*strings.Builder)
|
||||||
|
}
|
||||||
|
|
||||||
|
// put string builder back to pool
|
||||||
|
func putSB(sb *strings.Builder) {
|
||||||
|
sb.Reset()
|
||||||
|
sbPool.Put(sb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func And(cond ...Conditioner) Conditioner {
|
func And(cond ...Conditioner) Conditioner {
|
||||||
@@ -231,13 +92,14 @@ func (c *CondGroup) Condition(args *[]any, argIdx int) string {
|
|||||||
defer putSB(sb)
|
defer putSB(sb)
|
||||||
|
|
||||||
sb.WriteString("(")
|
sb.WriteString("(")
|
||||||
|
currentIdx := argIdx
|
||||||
for i, cond := range c.cond {
|
for i, cond := range c.cond {
|
||||||
if i == 0 {
|
if i > 0 {
|
||||||
sb.WriteString(cond.Condition(args, argIdx+i))
|
|
||||||
} else {
|
|
||||||
sb.WriteString(c.op)
|
sb.WriteString(c.op)
|
||||||
sb.WriteString(cond.Condition(args, argIdx+i))
|
|
||||||
}
|
}
|
||||||
|
argsLenBefore := len(*args)
|
||||||
|
sb.WriteString(cond.Condition(args, currentIdx))
|
||||||
|
currentIdx += len(*args) - argsLenBefore
|
||||||
}
|
}
|
||||||
sb.WriteString(")")
|
sb.WriteString(")")
|
||||||
return sb.String()
|
return sb.String()
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
DeleteCluase interface {
|
||||||
|
WhereOrExec
|
||||||
|
}
|
||||||
|
|
||||||
deleteQry struct {
|
deleteQry struct {
|
||||||
table string
|
table string
|
||||||
condition []Conditioner
|
condition []Conditioner
|
||||||
@@ -15,14 +19,11 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t *Table) Delete() WhereOrExec {
|
// WARNING: DELETE without WHERE clause will delete ALL rows in the table.
|
||||||
qb := &deleteQry{
|
// Always use Where() to specify conditions unless you intentionally want to delete all rows.
|
||||||
table: t.Name,
|
// Example:
|
||||||
debug: t.debug,
|
// table.Delete().Where(field.Eq(value)).Exec(ctx) // Safe
|
||||||
}
|
// table.Delete().Exec(ctx) // Deletes ALL rows!
|
||||||
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *deleteQry) Where(cond ...Conditioner) WhereOrExec {
|
func (q *deleteQry) Where(cond ...Conditioner) WhereOrExec {
|
||||||
q.condition = append(q.condition, cond...)
|
q.condition = append(q.condition, cond...)
|
||||||
|
|||||||
@@ -5,36 +5,40 @@ package pgm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type insertQry struct {
|
type (
|
||||||
returing *string
|
InsertClause interface {
|
||||||
onConflict *string
|
Insert
|
||||||
|
Returning(field Field) First
|
||||||
table string
|
OnConflict(fields ...Field) Do
|
||||||
conflictAction string
|
Execute
|
||||||
|
Stringer
|
||||||
fields []string
|
|
||||||
vals []string
|
|
||||||
args []any
|
|
||||||
debug bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Table) Insert() Insert {
|
|
||||||
qb := &insertQry{
|
|
||||||
table: t.Name,
|
|
||||||
fields: make([]string, 0, t.FieldCount),
|
|
||||||
vals: make([]string, 0, t.FieldCount),
|
|
||||||
args: make([]any, 0, t.FieldCount),
|
|
||||||
debug: t.debug,
|
|
||||||
}
|
}
|
||||||
return qb
|
|
||||||
}
|
Insert interface {
|
||||||
|
Set(field Field, val any) InsertClause
|
||||||
|
SetMap(fields map[Field]any) InsertClause
|
||||||
|
}
|
||||||
|
|
||||||
|
insertQry struct {
|
||||||
|
returing *string
|
||||||
|
onConflict *string
|
||||||
|
|
||||||
|
table string
|
||||||
|
conflictAction string
|
||||||
|
|
||||||
|
fields []string
|
||||||
|
vals []string
|
||||||
|
args []any
|
||||||
|
debug bool
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
func (q *insertQry) Set(field Field, val any) InsertClause {
|
func (q *insertQry) Set(field Field, val any) InsertClause {
|
||||||
q.fields = append(q.fields, field.Name())
|
q.fields = append(q.fields, field.Name())
|
||||||
@@ -51,7 +55,7 @@ func (q *insertQry) SetMap(cols map[Field]any) InsertClause {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) Returning(field Field) First {
|
func (q *insertQry) Returning(field Field) First {
|
||||||
col := field.Name()
|
col := field.String()
|
||||||
q.returing = &col
|
q.returing = &col
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
@@ -80,21 +84,28 @@ func (q *insertQry) DoNothing() Execute {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) DoUpdate(fields ...Field) Execute {
|
func (q *insertQry) DoUpdate(fields ...Field) Execute {
|
||||||
var sb strings.Builder
|
sb := getSB()
|
||||||
|
defer putSB(sb)
|
||||||
|
|
||||||
|
sb.WriteString("DO UPDATE SET ")
|
||||||
for i, f := range fields {
|
for i, f := range fields {
|
||||||
col := f.Name()
|
col := f.Name()
|
||||||
if i == 0 {
|
if i > 0 {
|
||||||
fmt.Fprintf(&sb, "%s = EXCLUDED.%s", col, col)
|
sb.WriteString(", ")
|
||||||
} else {
|
|
||||||
fmt.Fprintf(&sb, ", %s = EXCLUDED.%s", col, col)
|
|
||||||
}
|
}
|
||||||
|
sb.WriteString(col)
|
||||||
|
sb.WriteString(" = EXCLUDED.")
|
||||||
|
sb.WriteString(col)
|
||||||
}
|
}
|
||||||
|
|
||||||
q.conflictAction = "DO UPDATE SET " + sb.String()
|
q.conflictAction = sb.String()
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) Exec(ctx context.Context) error {
|
func (q *insertQry) Exec(ctx context.Context) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before Exec()")
|
||||||
|
}
|
||||||
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -103,6 +114,9 @@ func (q *insertQry) Exec(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before ExecTx()")
|
||||||
|
}
|
||||||
_, err := tx.Exec(ctx, q.String(), q.args...)
|
_, err := tx.Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -112,10 +126,16 @@ func (q *insertQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) First(ctx context.Context, dest ...any) error {
|
func (q *insertQry) First(ctx context.Context, dest ...any) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before First()")
|
||||||
|
}
|
||||||
return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
return poolPGX.Load().QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
func (q *insertQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
||||||
|
if len(q.fields) == 0 {
|
||||||
|
return errors.New("insert query has no fields to insert: call Set() before FirstTx()")
|
||||||
|
}
|
||||||
return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
return tx.QueryRow(ctx, q.String(), q.args...).Scan(dest...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
245
qry_select.go
245
qry_select.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -14,7 +15,128 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
SelectClause interface {
|
||||||
|
// Join and Inner Join are same
|
||||||
|
Join(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||||
|
LeftJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||||
|
RightJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||||
|
FullJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause
|
||||||
|
CrossJoin(m Table) SelectClause
|
||||||
|
WhereClause
|
||||||
|
OrderByClause
|
||||||
|
GroupByClause
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
raw(prefixArgs []any) (string, []any)
|
||||||
|
}
|
||||||
|
|
||||||
|
WhereClause interface {
|
||||||
|
Where(cond ...Conditioner) AfterWhere
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterWhere interface {
|
||||||
|
WhereClause
|
||||||
|
GroupByClause
|
||||||
|
OrderByClause
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
GroupByClause interface {
|
||||||
|
GroupBy(fields ...Field) AfterGroupBy
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterGroupBy interface {
|
||||||
|
HavinClause
|
||||||
|
OrderByClause
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
HavinClause interface {
|
||||||
|
Having(cond ...Conditioner) AfterHaving
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterHaving interface {
|
||||||
|
OrderByClause
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
OrderByClause interface {
|
||||||
|
OrderBy(fields ...Field) AfterOrderBy
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterOrderBy interface {
|
||||||
|
LimitClause
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
LimitClause interface {
|
||||||
|
Limit(v int) AfterLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterLimit interface {
|
||||||
|
OffsetClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
OffsetClause interface {
|
||||||
|
Offset(v int) AfterOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
AfterOffset interface {
|
||||||
|
LimitClause
|
||||||
|
Query
|
||||||
|
}
|
||||||
|
|
||||||
|
Do interface {
|
||||||
|
DoNothing() Execute
|
||||||
|
DoUpdate(fields ...Field) Execute
|
||||||
|
}
|
||||||
|
|
||||||
|
Query interface {
|
||||||
|
First
|
||||||
|
All
|
||||||
|
Stringer
|
||||||
|
Bulder
|
||||||
|
}
|
||||||
|
|
||||||
|
RowScanner interface {
|
||||||
|
Scan(dest ...any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
RowsCb func(row RowScanner) error
|
||||||
|
|
||||||
|
First interface {
|
||||||
|
First(ctx context.Context, dest ...any) error
|
||||||
|
FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error
|
||||||
|
Stringer
|
||||||
|
}
|
||||||
|
|
||||||
|
All interface {
|
||||||
|
// Query rows
|
||||||
|
//
|
||||||
|
// don't forget to close() rows
|
||||||
|
All(ctx context.Context, rows RowsCb) error
|
||||||
|
// Query rows
|
||||||
|
//
|
||||||
|
// don't forget to close() rows
|
||||||
|
AllTx(ctx context.Context, tx pgx.Tx, rows RowsCb) error
|
||||||
|
}
|
||||||
|
|
||||||
|
Bulder interface {
|
||||||
|
Build(needArgs bool) (qry string, args []any)
|
||||||
|
}
|
||||||
|
|
||||||
selectQry struct {
|
selectQry struct {
|
||||||
|
textSearch *textSearchCTE
|
||||||
|
|
||||||
table string
|
table string
|
||||||
fields []Field
|
fields []Field
|
||||||
args []any
|
args []any
|
||||||
@@ -23,9 +145,11 @@ type (
|
|||||||
groupBy []Field
|
groupBy []Field
|
||||||
having []Conditioner
|
having []Conditioner
|
||||||
orderBy []Field
|
orderBy []Field
|
||||||
limit int
|
|
||||||
offset int
|
limit int
|
||||||
debug bool
|
offset int
|
||||||
|
|
||||||
|
debug bool
|
||||||
}
|
}
|
||||||
|
|
||||||
CondAction uint8
|
CondAction uint8
|
||||||
@@ -51,34 +175,47 @@ const (
|
|||||||
CondActionSubQuery
|
CondActionSubQuery
|
||||||
)
|
)
|
||||||
|
|
||||||
// Select clause
|
func (q *selectQry) Join(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
func (t Table) Select(field ...Field) SelectClause {
|
return q.buildJoin(t, "JOIN", t1Field, t2Field, cond...)
|
||||||
qb := &selectQry{
|
}
|
||||||
table: t.Name,
|
|
||||||
debug: t.debug,
|
func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
fields: field,
|
return q.buildJoin(t, "LEFT JOIN", t1Field, t2Field, cond...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
|
return q.buildJoin(t, "RIGHT JOIN", t1Field, t2Field, cond...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
|
return q.buildJoin(t, "FULL JOIN", t1Field, t2Field, cond...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, cond ...Conditioner) SelectClause {
|
||||||
|
str := joinKW + " " + t.Name + " ON " + t1Field.String() + " = " + t2Field.String()
|
||||||
|
if len(cond) == 0 { // Join with no condition
|
||||||
|
q.join = append(q.join, str)
|
||||||
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
return qb
|
// Join has condition(s)
|
||||||
}
|
sb := getSB()
|
||||||
|
defer putSB(sb)
|
||||||
|
sb.Grow(len(str) * 2)
|
||||||
|
|
||||||
func (q *selectQry) Join(t Table, t1Field, t2Field Field) SelectClause {
|
sb.WriteString(str)
|
||||||
q.join = append(q.join, "JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
sb.WriteString(" AND ")
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field) SelectClause {
|
var argIdx int
|
||||||
q.join = append(q.join, "LEFT JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
for i, c := range cond {
|
||||||
return q
|
argIdx = len(q.args)
|
||||||
}
|
if i > 0 {
|
||||||
|
sb.WriteString(" AND ")
|
||||||
|
}
|
||||||
|
sb.WriteString(c.Condition(&q.args, argIdx))
|
||||||
|
}
|
||||||
|
|
||||||
func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field) SelectClause {
|
q.join = append(q.join, sb.String())
|
||||||
q.join = append(q.join, "RIGHT JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field) SelectClause {
|
|
||||||
q.join = append(q.join, "FULL JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String())
|
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,9 +264,13 @@ func (q *selectQry) FirstTx(ctx context.Context, tx pgx.Tx, dest ...any) error {
|
|||||||
|
|
||||||
func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
||||||
rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...)
|
rows, err := poolPGX.Load().Query(ctx, q.String(), q.args...)
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if err != nil {
|
||||||
return ErrNoRows
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return ErrNoRows
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@@ -138,13 +279,21 @@ func (q *selectQry) All(ctx context.Context, row RowsCb) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for errors from iteration
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
||||||
rows, err := tx.Query(ctx, q.String(), q.args...)
|
rows, err := tx.Query(ctx, q.String(), q.args...)
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if err != nil {
|
||||||
return ErrNoRows
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return ErrNoRows
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
@@ -154,6 +303,11 @@ func (q *selectQry) AllTx(ctx context.Context, tx pgx.Tx, row RowsCb) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for errors from iteration
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,11 +317,28 @@ func (q *selectQry) raw(prefixArgs []any) (string, []any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *selectQry) String() string {
|
func (q *selectQry) String() string {
|
||||||
|
qry, _ := q.Build(false)
|
||||||
|
if q.debug {
|
||||||
|
fmt.Println("***")
|
||||||
|
fmt.Println(qry)
|
||||||
|
fmt.Printf("%+v\n", q.args)
|
||||||
|
fmt.Println("***")
|
||||||
|
}
|
||||||
|
return qry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *selectQry) Build(needArgs bool) (qry string, args []any) {
|
||||||
sb := getSB()
|
sb := getSB()
|
||||||
defer putSB(sb)
|
defer putSB(sb)
|
||||||
|
|
||||||
sb.Grow(q.averageLen())
|
sb.Grow(q.averageLen())
|
||||||
|
|
||||||
|
if q.textSearch != nil {
|
||||||
|
var ts = q.textSearch
|
||||||
|
q.args = slices.Insert(q.args, 0, any(ts.value))
|
||||||
|
sb.WriteString("WITH " + ts.name + " AS (SELECT to_tsquery('english', $1) AS " + ts.alias + ") ")
|
||||||
|
}
|
||||||
|
|
||||||
// SELECT
|
// SELECT
|
||||||
sb.WriteString("SELECT ")
|
sb.WriteString("SELECT ")
|
||||||
sb.WriteString(joinFileds(q.fields))
|
sb.WriteString(joinFileds(q.fields))
|
||||||
@@ -178,6 +349,11 @@ func (q *selectQry) String() string {
|
|||||||
sb.WriteString(" " + strings.Join(q.join, " "))
|
sb.WriteString(" " + strings.Join(q.join, " "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Search Query Cross join
|
||||||
|
if q.textSearch != nil {
|
||||||
|
sb.WriteString(" CROSS JOIN " + q.textSearch.name)
|
||||||
|
}
|
||||||
|
|
||||||
// WHERE
|
// WHERE
|
||||||
if len(q.where) > 0 {
|
if len(q.where) > 0 {
|
||||||
sb.WriteString(" WHERE ")
|
sb.WriteString(" WHERE ")
|
||||||
@@ -228,14 +404,19 @@ func (q *selectQry) String() string {
|
|||||||
sb.WriteString(strconv.Itoa(q.offset))
|
sb.WriteString(strconv.Itoa(q.offset))
|
||||||
}
|
}
|
||||||
|
|
||||||
qry := sb.String()
|
qry = sb.String()
|
||||||
if q.debug {
|
if q.debug {
|
||||||
fmt.Println("***")
|
fmt.Println("***")
|
||||||
fmt.Println(qry)
|
fmt.Println(qry)
|
||||||
fmt.Printf("%+v\n", q.args)
|
fmt.Printf("%+v\n", q.args)
|
||||||
fmt.Println("***")
|
fmt.Println("***")
|
||||||
}
|
}
|
||||||
return qry
|
|
||||||
|
if needArgs {
|
||||||
|
args = slices.Clone(q.args)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *selectQry) averageLen() int {
|
func (q *selectQry) averageLen() int {
|
||||||
|
|||||||
@@ -5,29 +5,37 @@ package pgm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type updateQry struct {
|
type (
|
||||||
table string
|
Update interface {
|
||||||
cols []string
|
Set(field Field, val any) UpdateClause
|
||||||
condition []Conditioner
|
SetMap(fields map[Field]any) UpdateClause
|
||||||
args []any
|
|
||||||
debug bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Table) Update() Update {
|
|
||||||
qb := &updateQry{
|
|
||||||
table: t.Name,
|
|
||||||
debug: t.debug,
|
|
||||||
cols: make([]string, 0, t.FieldCount),
|
|
||||||
args: make([]any, 0, t.FieldCount),
|
|
||||||
}
|
}
|
||||||
return qb
|
|
||||||
}
|
UpdateClause interface {
|
||||||
|
Update
|
||||||
|
Where(cond ...Conditioner) WhereOrExec
|
||||||
|
}
|
||||||
|
|
||||||
|
WhereOrExec interface {
|
||||||
|
Where(cond ...Conditioner) WhereOrExec
|
||||||
|
Execute
|
||||||
|
}
|
||||||
|
|
||||||
|
updateQry struct {
|
||||||
|
table string
|
||||||
|
cols []string
|
||||||
|
condition []Conditioner
|
||||||
|
args []any
|
||||||
|
debug bool
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
func (q *updateQry) Set(field Field, val any) UpdateClause {
|
func (q *updateQry) Set(field Field, val any) UpdateClause {
|
||||||
col := field.Name()
|
col := field.Name()
|
||||||
@@ -49,6 +57,9 @@ func (q *updateQry) Where(cond ...Conditioner) WhereOrExec {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *updateQry) Exec(ctx context.Context) error {
|
func (q *updateQry) Exec(ctx context.Context) error {
|
||||||
|
if len(q.cols) == 0 {
|
||||||
|
return errors.New("update query has no columns to update: call Set() before Exec()")
|
||||||
|
}
|
||||||
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -58,6 +69,9 @@ func (q *updateQry) Exec(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *updateQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
func (q *updateQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
if len(q.cols) == 0 {
|
||||||
|
return errors.New("update query has no columns to update: call Set() before ExecTx()")
|
||||||
|
}
|
||||||
_, err := tx.Exec(ctx, q.String(), q.args...)
|
_, err := tx.Exec(ctx, q.String(), q.args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
Reference in New Issue
Block a user