Audit with AI
This commit is contained in:
17
Makefile
17
Makefile
@@ -1,4 +1,7 @@
|
||||
.PHONY: run bench-select test
|
||||
.PHONY: run bench-select test build install
|
||||
|
||||
# Version can be set via: make build VERSION=v1.2.3
|
||||
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
|
||||
run:
|
||||
go run ./cmd -o ./playground/db ./playground/schema.sql
|
||||
@@ -8,3 +11,15 @@ bench-select:
|
||||
|
||||
test:
|
||||
go test ./playground
|
||||
|
||||
# Build with version information
|
||||
build:
|
||||
go build -ldflags "-X main.version=$(VERSION)" -o pgm ./cmd
|
||||
|
||||
# Install to GOPATH/bin with version
|
||||
install:
|
||||
go install -ldflags "-X main.version=$(VERSION)" ./cmd
|
||||
|
||||
# Show current version
|
||||
version:
|
||||
@echo "Version: $(VERSION)"
|
||||
|
||||
720
README.md
720
README.md
@@ -1,81 +1,711 @@
|
||||
# pgm - PostgreSQL Query Mapper
|
||||
|
||||
A lightweight ORM built on top of [jackc/pgx](https://github.com/jackc/pgx) database connection pool.
|
||||
[](https://pkg.go.dev/code.patial.tech/go/pgm)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
## ORMs I Like in the Go Ecosystem
|
||||
A lightweight, type-safe PostgreSQL query builder for Go, built on top of [jackc/pgx](https://github.com/jackc/pgx). **pgm** generates Go code from your SQL schema, enabling you to write SQL queries with compile-time safety and autocompletion support.
|
||||
|
||||
- [ent](https://github.com/ent/ent)
|
||||
- [sqlc](https://github.com/sqlc-dev/sqlc)
|
||||
## Features
|
||||
|
||||
## Why Not Use `ent`?
|
||||
- **Type-safe queries** - Column and table names are validated at compile time
|
||||
- **Zero reflection** - Fast performance with no runtime reflection overhead
|
||||
- **SQL schema-based** - Generate Go code directly from your SQL schema files
|
||||
- **Fluent API** - Intuitive query builder with method chaining
|
||||
- **Transaction support** - First-class support for pgx transactions
|
||||
- **Full-text search** - Built-in PostgreSQL full-text search helpers
|
||||
- **Connection pooling** - Leverages pgx connection pool for optimal performance
|
||||
- **Minimal code generation** - Only generates what you need, no bloat
|
||||
|
||||
`ent` is a feature-rich ORM with schema definition, automatic migrations, integration with `gqlgen` (GraphQL server), and more. It provides nearly everything you could want in an ORM.
|
||||
## Table of Contents
|
||||
|
||||
However, it can be overkill. The generated code supports a wide range of features, many of which you may not use, significantly increasing the compiled binary size.
|
||||
- [Why pgm?](#why-pgm)
|
||||
- [Installation](#installation)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Usage Examples](#usage-examples)
|
||||
- [SELECT Queries](#select-queries)
|
||||
- [INSERT Queries](#insert-queries)
|
||||
- [UPDATE Queries](#update-queries)
|
||||
- [DELETE Queries](#delete-queries)
|
||||
- [Joins](#joins)
|
||||
- [Transactions](#transactions)
|
||||
- [Full-Text Search](#full-text-search)
|
||||
- [CLI Tool](#cli-tool)
|
||||
- [API Documentation](#api-documentation)
|
||||
- [Contributing](#contributing)
|
||||
- [License](#license)
|
||||
|
||||
## Why Not Use `sqlc`?
|
||||
## Why pgm?
|
||||
|
||||
`sqlc` is a great tool, but it often feels like the database layer introduces its own models. This forces you to either map your application’s models to these database models or use the database models directly, which may not align with your application’s design.
|
||||
### The Problem with Existing ORMs
|
||||
|
||||
## Issues with Existing ORMs
|
||||
While Go has excellent ORMs like [ent](https://github.com/ent/ent) and [sqlc](https://github.com/sqlc-dev/sqlc), they come with tradeoffs:
|
||||
|
||||
Here are some common pain points with ORMs:
|
||||
**ent** - Feature-rich but heavy:
|
||||
- Generates extensive code for features you may never use
|
||||
- Significantly increases binary size
|
||||
- Complex schema definition in Go instead of SQL
|
||||
- Auto-migrations can obscure actual database schema
|
||||
|
||||
- **Auto Migrations**: Many ORMs either lack robust migration support or implement complex methods for simple schema changes. This can obscure the database schema, making it harder to understand and maintain. A database schema should be defined in clear SQL statements that can be tested in a SQL query editor. Tools like [dbmate](https://github.com/amacneil/dbmate) provide a mature solution for managing migrations, usable via CLI or in code.
|
||||
**sqlc** - Great tool, but:
|
||||
- Creates separate database models, forcing model mapping
|
||||
- Query results require their own generated types
|
||||
- Less flexibility in dynamic query building
|
||||
|
||||
- **Excessive Code Generation**: ORMs often generate excessive code for various conditions and scenarios, much of which goes unused.
|
||||
### The pgm Approach
|
||||
|
||||
- **Generated Models for Queries**: Auto-generated models for `SELECT` queries force you to either adopt them or map them to your application’s models, adding complexity.
|
||||
**pgm** takes a hybrid approach:
|
||||
|
||||
## A Hybrid Approach: Plain SQL Queries with `pgm`
|
||||
✅ **Schema as SQL** - Define your database schema in pure SQL, where it belongs
|
||||
✅ **Minimal generation** - Only generates table and column definitions
|
||||
✅ **Your models** - Use your own application models, no forced abstractions
|
||||
✅ **Type safety** - Catch schema changes at compile time
|
||||
✅ **SQL power** - Full control over your queries with a fluent API
|
||||
✅ **Migration-friendly** - Use mature tools like [dbmate](https://github.com/amacneil/dbmate) for migrations
|
||||
|
||||
Plain SQL queries are not inherently bad but come with challenges:
|
||||
|
||||
- **Schema Change Detection**: Changes in the database schema are not easily detected.
|
||||
- **SQL Injection Risks**: Without parameterized queries, SQL injection becomes a concern.
|
||||
|
||||
`pgm` addresses these issues by providing a lightweight CLI tool that generates Go files for your database schema. These files help you write SQL queries while keeping track of schema changes, avoiding hardcoded table and column names.
|
||||
|
||||
## Generating `pgm` Schema Files
|
||||
|
||||
Run the following command to generate schema files:
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go run code.partial.tech/go/pgm/cmd -o ./db ./schema.sql
|
||||
go get code.patial.tech/go/pgm
|
||||
```
|
||||
once you have the schama files created you can use `pgm` as
|
||||
|
||||
Install the CLI tool for schema code generation:
|
||||
|
||||
```bash
|
||||
go install code.patial.tech/go/pgm/cmd@latest
|
||||
```
|
||||
|
||||
### Building from Source
|
||||
|
||||
Build with automatic version detection (uses git tags):
|
||||
|
||||
```bash
|
||||
# Build with version from git tags
|
||||
make build
|
||||
|
||||
# Build with specific version
|
||||
make build VERSION=v1.2.3
|
||||
|
||||
# Install to GOPATH/bin
|
||||
make install
|
||||
|
||||
# Or build manually with version
|
||||
go build -ldflags "-X main.version=v1.2.3" -o pgm ./cmd
|
||||
```
|
||||
|
||||
Check the version:
|
||||
|
||||
```bash
|
||||
pgm -version
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Create Your Schema
|
||||
|
||||
Create a SQL schema file `schema.sql`:
|
||||
|
||||
```sql
|
||||
CREATE TABLE users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
email VARCHAR(255) UNIQUE NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE posts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id),
|
||||
title VARCHAR(500) NOT NULL,
|
||||
content TEXT,
|
||||
published BOOLEAN DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
```
|
||||
|
||||
### 2. Generate Go Code
|
||||
|
||||
Run the pgm CLI tool:
|
||||
|
||||
```bash
|
||||
pgm -o ./db ./schema.sql
|
||||
```
|
||||
|
||||
This generates Go files for each table in `./db/`:
|
||||
- `db/users/users.go` - Table and column definitions for users
|
||||
- `db/posts/posts.go` - Table and column definitions for posts
|
||||
|
||||
### 3. Use in Your Code
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"code.partial.tech/go/pgm"
|
||||
"myapp/db/user" // scham create by pgm/cmd
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"code.patial.tech/go/pgm"
|
||||
"yourapp/db/users"
|
||||
"yourapp/db/posts"
|
||||
)
|
||||
|
||||
type MyModel struct {
|
||||
ID string
|
||||
Email string
|
||||
}
|
||||
|
||||
func main() {
|
||||
println("Initializing pgx connection pool")
|
||||
// Initialize connection pool
|
||||
pgm.InitPool(pgm.Config{
|
||||
ConnString: url,
|
||||
ConnString: "postgres://user:pass@localhost:5432/dbname",
|
||||
MaxConns: 25,
|
||||
MinConns: 5,
|
||||
})
|
||||
defer pgm.ClosePool() // Ensure graceful shutdown
|
||||
|
||||
// Select query to fetch the first record
|
||||
// Assumes the schema is defined in the "db" package with a User table
|
||||
var v MyModel
|
||||
err := db.User.Select(user.ID, user.Email).
|
||||
Where(user.Email.Like("anki%")).
|
||||
First(context.TODO(), &v.ID, &v.Email)
|
||||
ctx := context.Background()
|
||||
|
||||
// Query a user
|
||||
var email string
|
||||
err := users.User.Select(users.Email).
|
||||
Where(users.ID.Eq("some-uuid")).
|
||||
First(ctx, &email)
|
||||
if err != nil {
|
||||
println("Error:", err.Error())
|
||||
return
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
println("User email:", v.Email)
|
||||
log.Printf("User email: %s", email)
|
||||
}
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### SELECT Queries
|
||||
|
||||
#### Basic Select
|
||||
|
||||
```go
|
||||
var user struct {
|
||||
ID string
|
||||
Email string
|
||||
Name string
|
||||
}
|
||||
|
||||
err := users.User.Select(users.ID, users.Email, users.Name).
|
||||
Where(users.Email.Eq("john@example.com")).
|
||||
First(ctx, &user.ID, &user.Email, &user.Name)
|
||||
```
|
||||
|
||||
#### Select with Multiple Conditions
|
||||
|
||||
```go
|
||||
err := users.User.Select(users.ID, users.Email).
|
||||
Where(
|
||||
users.Email.Like("john%"),
|
||||
users.CreatedAt.Gt(time.Now().AddDate(0, -1, 0)),
|
||||
).
|
||||
OrderBy(users.CreatedAt.Desc()).
|
||||
Limit(10).
|
||||
First(ctx, &user.ID, &user.Email)
|
||||
```
|
||||
|
||||
#### Select All with Callback
|
||||
|
||||
```go
|
||||
var userList []User
|
||||
|
||||
err := users.User.Select(users.ID, users.Email, users.Name).
|
||||
Where(users.Name.Like("J%")).
|
||||
OrderBy(users.Name.Asc()).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
var u User
|
||||
if err := row.Scan(&u.ID, &u.Email, &u.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
userList = append(userList, u)
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
#### Pagination
|
||||
|
||||
```go
|
||||
page := 2
|
||||
pageSize := 20
|
||||
|
||||
err := users.User.Select(users.ID, users.Email).
|
||||
OrderBy(users.CreatedAt.Desc()).
|
||||
Limit(pageSize).
|
||||
Offset((page - 1) * pageSize).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
// Process rows
|
||||
})
|
||||
```
|
||||
|
||||
#### Grouping and Having
|
||||
|
||||
```go
|
||||
err := posts.Post.Select(posts.UserID, pgm.Count(posts.ID)).
|
||||
GroupBy(posts.UserID).
|
||||
Having(pgm.Count(posts.ID).Gt(5)).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
var userID string
|
||||
var postCount int
|
||||
return row.Scan(&userID, &postCount)
|
||||
})
|
||||
```
|
||||
|
||||
### INSERT Queries
|
||||
|
||||
#### Simple Insert
|
||||
|
||||
```go
|
||||
err := users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe").
|
||||
Set(users.CreatedAt, pgm.PgTimeNow()).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Insert with Map
|
||||
|
||||
```go
|
||||
data := map[pgm.Field]any{
|
||||
users.Email: "jane@example.com",
|
||||
users.Name: "Jane Doe",
|
||||
users.CreatedAt: pgm.PgTimeNow(),
|
||||
}
|
||||
|
||||
err := users.User.Insert().
|
||||
SetMap(data).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Insert with RETURNING
|
||||
|
||||
```go
|
||||
var newID string
|
||||
|
||||
err := users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe").
|
||||
Returning(users.ID).
|
||||
First(ctx, &newID)
|
||||
```
|
||||
|
||||
#### Upsert (INSERT ... ON CONFLICT)
|
||||
|
||||
```go
|
||||
// Do nothing on conflict
|
||||
err := users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe").
|
||||
OnConflict(users.Email).
|
||||
DoNothing().
|
||||
Exec(ctx)
|
||||
|
||||
// Update on conflict
|
||||
err := users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe Updated").
|
||||
OnConflict(users.Email).
|
||||
DoUpdate(users.Name).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
### UPDATE Queries
|
||||
|
||||
#### Simple Update
|
||||
|
||||
```go
|
||||
err := users.User.Update().
|
||||
Set(users.Name, "John Smith").
|
||||
Where(users.ID.Eq("some-uuid")).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Update Multiple Fields
|
||||
|
||||
```go
|
||||
updates := map[pgm.Field]any{
|
||||
users.Name: "John Smith",
|
||||
users.Email: "john.smith@example.com",
|
||||
}
|
||||
|
||||
err := users.User.Update().
|
||||
SetMap(updates).
|
||||
Where(users.ID.Eq("some-uuid")).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Conditional Update
|
||||
|
||||
```go
|
||||
err := users.User.Update().
|
||||
Set(users.Name, "Updated Name").
|
||||
Where(
|
||||
users.Email.Like("%@example.com"),
|
||||
users.CreatedAt.Lt(time.Now().AddDate(-1, 0, 0)),
|
||||
).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
### DELETE Queries
|
||||
|
||||
#### Simple Delete
|
||||
|
||||
```go
|
||||
err := users.User.Delete().
|
||||
Where(users.ID.Eq("some-uuid")).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
#### Conditional Delete
|
||||
|
||||
```go
|
||||
err := posts.Post.Delete().
|
||||
Where(
|
||||
posts.Published.Eq(false),
|
||||
posts.CreatedAt.Lt(time.Now().AddDate(0, 0, -30)),
|
||||
).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
### Joins
|
||||
|
||||
#### Inner Join
|
||||
|
||||
```go
|
||||
err := posts.Post.Select(posts.Title, users.Name).
|
||||
Join(users.User, posts.UserID, users.ID).
|
||||
Where(users.Email.Eq("john@example.com")).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
var title, userName string
|
||||
return row.Scan(&title, &userName)
|
||||
})
|
||||
```
|
||||
|
||||
#### Left Join
|
||||
|
||||
```go
|
||||
err := users.User.Select(users.Name, posts.Title).
|
||||
LeftJoin(posts.Post, users.ID, posts.UserID).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
var userName, postTitle string
|
||||
return row.Scan(&userName, &postTitle)
|
||||
})
|
||||
```
|
||||
|
||||
#### Join with Additional Conditions
|
||||
|
||||
```go
|
||||
err := posts.Post.Select(posts.Title, users.Name).
|
||||
Join(users.User, posts.UserID, users.ID, users.Email.Like("%@example.com")).
|
||||
Where(posts.Published.Eq(true)).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
// Process rows
|
||||
})
|
||||
```
|
||||
|
||||
### Transactions
|
||||
|
||||
#### Basic Transaction
|
||||
|
||||
```go
|
||||
tx, err := pgm.BeginTx(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Insert user
|
||||
var userID string
|
||||
err = users.User.Insert().
|
||||
Set(users.Email, "jane@example.com").
|
||||
Set(users.Name, "Jane Doe").
|
||||
Returning(users.ID).
|
||||
FirstTx(ctx, tx, &userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert post
|
||||
err = posts.Post.Insert().
|
||||
Set(posts.UserID, userID).
|
||||
Set(posts.Title, "My First Post").
|
||||
Set(posts.Content, "Hello, World!").
|
||||
ExecTx(ctx, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
### Full-Text Search
|
||||
|
||||
PostgreSQL full-text search helpers:
|
||||
|
||||
```go
|
||||
// Search with AND operator (all terms must match)
|
||||
searchQuery := pgm.TsAndQuery("golang database")
|
||||
// Result: "golang & database"
|
||||
|
||||
// Search with prefix matching
|
||||
searchQuery := pgm.TsPrefixAndQuery("gol data")
|
||||
// Result: "gol:* & data:*"
|
||||
|
||||
// Search with OR operator (any term matches)
|
||||
searchQuery := pgm.TsOrQuery("golang rust")
|
||||
// Result: "golang | rust"
|
||||
|
||||
// Prefix OR search
|
||||
searchQuery := pgm.TsPrefixOrQuery("go ru")
|
||||
// Result: "go:* | ru:*"
|
||||
|
||||
// Use in query (assuming you have a tsvector column)
|
||||
err := posts.Post.Select(posts.Title, posts.Content).
|
||||
Where(posts.SearchVector.Match(pgm.TsPrefixAndQuery(searchTerm))).
|
||||
OrderBy(posts.CreatedAt.Desc()).
|
||||
All(ctx, func(row pgm.RowScanner) error {
|
||||
// Process results
|
||||
})
|
||||
```
|
||||
|
||||
## CLI Tool
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
pgm -o <output_directory> <schema.sql>
|
||||
```
|
||||
|
||||
### Options
|
||||
|
||||
```bash
|
||||
-o string Output directory path (required)
|
||||
-version Show version information
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
```bash
|
||||
# Generate from a single schema file
|
||||
pgm -o ./db ./schema.sql
|
||||
|
||||
# Generate from concatenated migrations
|
||||
cat migrations/*.sql > /tmp/schema.sql && pgm -o ./db /tmp/schema.sql
|
||||
|
||||
# Check version
|
||||
pgm -version
|
||||
```
|
||||
|
||||
### Known Limitations
|
||||
|
||||
The CLI tool uses a regex-based SQL parser with the following limitations:
|
||||
|
||||
- ❌ Multi-line comments `/* */` are not supported
|
||||
- ❌ Complex data types (arrays, JSON, JSONB) may not parse correctly
|
||||
- ❌ Quoted identifiers with special characters may fail
|
||||
- ❌ Advanced PostgreSQL features (PARTITION BY, INHERITS) not supported
|
||||
- ❌ Some constraints (CHECK, EXCLUDE) are not parsed
|
||||
|
||||
**Workarounds:**
|
||||
- Use simple CREATE TABLE statements
|
||||
- Avoid complex PostgreSQL-specific syntax in schema files
|
||||
- Split complex schemas into multiple simple statements
|
||||
- Remove comments before running the generator
|
||||
|
||||
For complex schemas, consider contributing a more robust parser or using a proper SQL parser library.
|
||||
|
||||
### Generated Code Structure
|
||||
|
||||
For a table named `users`, pgm generates:
|
||||
|
||||
```
|
||||
db/
|
||||
└── users/
|
||||
└── users.go
|
||||
```
|
||||
|
||||
The generated file contains:
|
||||
- Generated code header with version and timestamp
|
||||
- Table definition (`User`)
|
||||
- Column field definitions (`ID`, `Email`, `Name`, etc.)
|
||||
- Type-safe query builders (`Select()`, `Insert()`, `Update()`, `Delete()`)
|
||||
|
||||
**Example header:**
|
||||
```go
|
||||
// Code generated by code.patial.tech/go/pgm/cmd v1.2.3 on 2025-01-27 15:04:05 DO NOT EDIT.
|
||||
```
|
||||
|
||||
The version in generated files helps track which version of the CLI tool was used, making it easier to identify when regeneration is needed after upgrades.
|
||||
|
||||
## API Documentation
|
||||
|
||||
### Connection Pool
|
||||
|
||||
#### InitPool
|
||||
|
||||
Initialize the connection pool (must be called once at startup):
|
||||
|
||||
```go
|
||||
pgm.InitPool(pgm.Config{
|
||||
ConnString: "postgres://...",
|
||||
MaxConns: 25,
|
||||
MinConns: 5,
|
||||
MaxConnLifetime: time.Hour,
|
||||
MaxConnIdleTime: time.Minute * 30,
|
||||
})
|
||||
```
|
||||
|
||||
**Configuration Validation:**
|
||||
- MinConns cannot be greater than MaxConns
|
||||
- Connection counts cannot be negative
|
||||
- Connection string is required
|
||||
|
||||
#### ClosePool
|
||||
|
||||
Close the connection pool gracefully (call during application shutdown):
|
||||
|
||||
```go
|
||||
func main() {
|
||||
pgm.InitPool(pgm.Config{
|
||||
ConnString: "postgres://...",
|
||||
})
|
||||
defer pgm.ClosePool() // Ensures proper cleanup
|
||||
|
||||
// Your application code
|
||||
}
|
||||
```
|
||||
|
||||
#### GetPool
|
||||
|
||||
Get the underlying pgx pool:
|
||||
|
||||
```go
|
||||
pool := pgm.GetPool()
|
||||
```
|
||||
|
||||
### Query Conditions
|
||||
|
||||
Available condition methods on fields:
|
||||
|
||||
- `Eq(value)` - Equal to
|
||||
- `NotEq(value)` - Not equal to
|
||||
- `Gt(value)` - Greater than
|
||||
- `Gte(value)` - Greater than or equal to
|
||||
- `Lt(value)` - Less than
|
||||
- `Lte(value)` - Less than or equal to
|
||||
- `Like(pattern)` - LIKE pattern match
|
||||
- `ILike(pattern)` - Case-insensitive LIKE
|
||||
- `In(values...)` - IN list
|
||||
- `NotIn(values...)` - NOT IN list
|
||||
- `IsNull()` - IS NULL
|
||||
- `IsNotNull()` - IS NOT NULL
|
||||
- `Between(start, end)` - BETWEEN range
|
||||
|
||||
### Field Methods
|
||||
|
||||
- `Asc()` - Sort ascending
|
||||
- `Desc()` - Sort descending
|
||||
- `Name()` - Get column name
|
||||
- `String()` - Get fully qualified name (table.column)
|
||||
|
||||
### Utilities
|
||||
|
||||
- `pgm.PgTime(t time.Time)` - Convert Go time to PostgreSQL timestamptz
|
||||
- `pgm.PgTimeNow()` - Current time as PostgreSQL timestamptz
|
||||
- `pgm.IsNotFound(err)` - Check if error is "no rows found"
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use transactions for related operations** - Ensure data consistency
|
||||
2. **Define schema in SQL** - Use migration tools like dbmate for schema management
|
||||
3. **Regenerate after schema changes** - Run the CLI tool after any schema modifications
|
||||
4. **Use your own models** - Don't let the database dictate your domain models
|
||||
5. **Handle pgx.ErrNoRows** - Use `pgm.IsNotFound(err)` for cleaner error checking
|
||||
6. **Always use context with timeouts** - Prevent queries from running indefinitely
|
||||
7. **Validate UPDATE queries** - Ensure Set() is called before Exec()
|
||||
8. **Be careful with DELETE** - Always use Where() unless you want to delete all rows
|
||||
|
||||
## Important Safety Notes
|
||||
|
||||
### ⚠️ DELETE Operations
|
||||
|
||||
DELETE without WHERE clause will delete ALL rows in the table:
|
||||
|
||||
```go
|
||||
// ❌ DANGEROUS - Deletes ALL rows!
|
||||
users.User.Delete().Exec(ctx)
|
||||
|
||||
// ✅ SAFE - Deletes specific rows
|
||||
users.User.Delete().Where(users.ID.Eq("user-id")).Exec(ctx)
|
||||
```
|
||||
|
||||
### ⚠️ UPDATE Operations
|
||||
|
||||
UPDATE requires at least one Set() call:
|
||||
|
||||
```go
|
||||
// ❌ ERROR - No columns to update
|
||||
users.User.Update().Where(users.ID.Eq(1)).Exec(ctx)
|
||||
// Returns: "update query has no columns to update"
|
||||
|
||||
// ✅ CORRECT
|
||||
users.User.Update().
|
||||
Set(users.Name, "New Name").
|
||||
Where(users.ID.Eq(1)).
|
||||
Exec(ctx)
|
||||
```
|
||||
|
||||
### ⚠️ Query Timeouts
|
||||
|
||||
Always use context with timeout to prevent hanging queries:
|
||||
|
||||
```go
|
||||
// ✅ RECOMMENDED
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := users.User.Select(users.Email).
|
||||
Where(users.ID.Eq("some-id")).
|
||||
First(ctx, &email)
|
||||
```
|
||||
|
||||
### ⚠️ Connection String Security
|
||||
|
||||
Never log or expose database connection strings as they contain credentials. The library does not sanitize connection strings in error messages.
|
||||
|
||||
## Performance
|
||||
|
||||
**pgm** is designed for performance:
|
||||
|
||||
- Zero reflection overhead
|
||||
- Efficient string building with sync.Pool
|
||||
- Leverages pgx's high-performance connection pooling
|
||||
- Minimal allocations in query building
|
||||
- Direct scanning into your types
|
||||
|
||||
## Requirements
|
||||
|
||||
- Go 1.20 or higher
|
||||
- PostgreSQL 12 or higher
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
## Author
|
||||
|
||||
|
||||
**Ankit Patial** - [Patial Tech](https://code.patial.tech)
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
Built on top of the excellent [jackc/pgx](https://github.com/jackc/pgx) library.
|
||||
|
||||
---
|
||||
|
||||
**Made with ❤️ for the Go community**
|
||||
|
||||
@@ -9,11 +9,16 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// version can be set at build time using:
|
||||
// go build -ldflags "-X main.version=v1.2.3" ./cmd
|
||||
var version = "dev"
|
||||
|
||||
func generate(scheamPath, outDir string) error {
|
||||
// read schame.sql
|
||||
f, err := os.ReadFile(scheamPath)
|
||||
@@ -29,14 +34,15 @@ func generate(scheamPath, outDir string) error {
|
||||
|
||||
// Output dir, create if not exists.
|
||||
if _, err := os.Stat(outDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(outDir, 0740); err != nil {
|
||||
if err := os.MkdirAll(outDir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// schema.go will hold all tables info
|
||||
var sb strings.Builder
|
||||
sb.WriteString("// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.\n\n")
|
||||
sb.WriteString(fmt.Sprintf("// Code generated by code.patial.tech/go/pgm/cmd %s on %s DO NOT EDIT.\n\n",
|
||||
GetVersionString(), time.Now().Format("2006-01-02 15:04:05")))
|
||||
sb.WriteString(fmt.Sprintf("package %s \n", filepath.Base(outDir)))
|
||||
sb.WriteString(`
|
||||
import "code.patial.tech/go/pgm"
|
||||
@@ -70,7 +76,7 @@ func generate(scheamPath, outDir string) error {
|
||||
sb.WriteString("}\n")
|
||||
}
|
||||
modalDir = strings.ToLower(name)
|
||||
os.Mkdir(filepath.Join(outDir, modalDir), 0740)
|
||||
os.Mkdir(filepath.Join(outDir, modalDir), 0755)
|
||||
|
||||
if err = writeColFile(t.Table, t.Columns, filepath.Join(outDir, modalDir), caser); err != nil {
|
||||
return err
|
||||
@@ -95,13 +101,14 @@ func generate(scheamPath, outDir string) error {
|
||||
}
|
||||
|
||||
// Save file to disk
|
||||
os.WriteFile(filepath.Join(outDir, "schema.go"), code, 0640)
|
||||
os.WriteFile(filepath.Join(outDir, "schema.go"), code, 0644)
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Caser) error {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.\n\n")
|
||||
sb.WriteString(fmt.Sprintf("// Code generated by code.patial.tech/go/pgm/cmd %s on %s DO NOT EDIT.\n\n",
|
||||
GetVersionString(), time.Now().Format("2006-01-02 15:04:05")))
|
||||
sb.WriteString(fmt.Sprintf("package %s\n\n", filepath.Base(outDir)))
|
||||
sb.WriteString(fmt.Sprintf("import %q\n\n", "code.patial.tech/go/pgm"))
|
||||
sb.WriteString("const (")
|
||||
@@ -129,7 +136,7 @@ func writeColFile(tblName string, cols []*Column, outDir string, caser cases.Cas
|
||||
return err
|
||||
}
|
||||
// Save file to disk.
|
||||
return os.WriteFile(filepath.Join(outDir, tblName+".go"), code, 0640)
|
||||
return os.WriteFile(filepath.Join(outDir, tblName+".go"), code, 0644)
|
||||
}
|
||||
|
||||
// pluralToSingular converts plural table names to singular forms
|
||||
|
||||
21
cmd/main.go
21
cmd/main.go
@@ -9,29 +9,42 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
const usageTxt = `Please provide output director and input schema.
|
||||
var (
|
||||
showVersion bool
|
||||
)
|
||||
|
||||
const usageTxt = `Please provide output directory and input schema.
|
||||
Example:
|
||||
pgm/cmd -o ./db ./db/schema.sql
|
||||
pgm -o ./db ./schema.sql
|
||||
|
||||
`
|
||||
|
||||
func main() {
|
||||
var outDir string
|
||||
flag.StringVar(&outDir, "o", "", "-o as output directory path")
|
||||
flag.BoolVar(&showVersion, "version", false, "show version information")
|
||||
flag.Parse()
|
||||
|
||||
// Handle version flag
|
||||
if showVersion {
|
||||
fmt.Printf("pgm %s\n", GetVersionString())
|
||||
fmt.Println("PostgreSQL Query Mapper - Schema code generator")
|
||||
fmt.Println("https://code.patial.tech/go/pgm")
|
||||
return
|
||||
}
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Print(usageTxt)
|
||||
return
|
||||
}
|
||||
|
||||
if outDir == "" {
|
||||
println("missing, -o output directory path")
|
||||
fmt.Fprintln(os.Stderr, "Error: missing output directory path (-o flag required)")
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
if err := generate(os.Args[3], outDir); err != nil {
|
||||
println(err.Error())
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
15
cmd/parse.go
15
cmd/parse.go
@@ -1,5 +1,20 @@
|
||||
// Patial Tech.
|
||||
// Author, Ankit Patial
|
||||
//
|
||||
// SQL Parser Limitations:
|
||||
// This is a simple regex-based SQL parser with the following known limitations:
|
||||
// - No support for multi-line comments /* */
|
||||
// - May struggle with complex data types (e.g., arrays, JSON, JSONB)
|
||||
// - No handling of quoted identifiers with special characters
|
||||
// - Advanced features like PARTITION BY, INHERITS not supported
|
||||
// - Does not handle all PostgreSQL-specific syntax
|
||||
// - Constraints like CHECK, EXCLUDE not parsed
|
||||
//
|
||||
// For complex schemas, consider:
|
||||
// 1. Simplifying your schema for generation
|
||||
// 2. Using multiple simple CREATE TABLE statements
|
||||
// 3. Contributing a more robust parser implementation
|
||||
// 4. Using a proper SQL parser library
|
||||
|
||||
package main
|
||||
|
||||
|
||||
67
cmd/version.go
Normal file
67
cmd/version.go
Normal file
@@ -0,0 +1,67 @@
|
||||
// Patial Tech.
|
||||
// Author, Ankit Patial
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Version returns the version of the pgm CLI tool.
|
||||
// It tries to detect the version in the following order:
|
||||
// 1. Build-time ldflags (set via: go build -ldflags "-X main.version=v1.2.3")
|
||||
// 2. VCS information from build metadata (git tag/commit)
|
||||
// 3. Falls back to "dev" if no version information is available
|
||||
func Version() string {
|
||||
// If version was set at build time via ldflags
|
||||
if version != "" && version != "dev" {
|
||||
return version
|
||||
}
|
||||
|
||||
// Try to get version from build info (Go 1.18+)
|
||||
if info, ok := debug.ReadBuildInfo(); ok {
|
||||
// Check for version in main module
|
||||
if info.Main.Version != "" && info.Main.Version != "(devel)" {
|
||||
return info.Main.Version
|
||||
}
|
||||
|
||||
// Try to extract from VCS information
|
||||
var revision, modified string
|
||||
for _, setting := range info.Settings {
|
||||
switch setting.Key {
|
||||
case "vcs.revision":
|
||||
revision = setting.Value
|
||||
case "vcs.modified":
|
||||
modified = setting.Value
|
||||
}
|
||||
}
|
||||
|
||||
// If we have a git revision
|
||||
if revision != "" {
|
||||
// Shorten commit hash to 7 characters
|
||||
if len(revision) > 7 {
|
||||
revision = revision[:7]
|
||||
}
|
||||
|
||||
// Add -dirty suffix if modified
|
||||
if modified == "true" {
|
||||
return "dev-" + revision + "-dirty"
|
||||
}
|
||||
return "dev-" + revision
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to dev
|
||||
return "dev"
|
||||
}
|
||||
|
||||
// GetVersionString returns a formatted version string for display
|
||||
func GetVersionString() string {
|
||||
v := Version()
|
||||
|
||||
// Clean up version string for display
|
||||
v = strings.TrimPrefix(v, "v")
|
||||
|
||||
return "v" + v
|
||||
}
|
||||
27
pgm.go
27
pgm.go
@@ -9,6 +9,7 @@ package pgm
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -45,6 +46,15 @@ func InitPool(conf Config) {
|
||||
panic(ErrConnStringMissing)
|
||||
}
|
||||
|
||||
// Validate configuration
|
||||
if conf.MaxConns > 0 && conf.MinConns > 0 && conf.MinConns > conf.MaxConns {
|
||||
panic(fmt.Errorf("MinConns (%d) cannot be greater than MaxConns (%d)", conf.MinConns, conf.MaxConns))
|
||||
}
|
||||
|
||||
if conf.MaxConns < 0 || conf.MinConns < 0 {
|
||||
panic(errors.New("connection pool configuration cannot have negative values"))
|
||||
}
|
||||
|
||||
cfg, err := pgxpool.ParseConfig(conf.ConnString)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -55,7 +65,7 @@ func InitPool(conf Config) {
|
||||
}
|
||||
|
||||
if conf.MinConns > 0 {
|
||||
cfg.MinConns = conf.MaxConns // 5
|
||||
cfg.MinConns = conf.MinConns // 5
|
||||
}
|
||||
|
||||
if conf.MaxConnLifetime > 0 {
|
||||
@@ -83,15 +93,24 @@ func GetPool() *pgxpool.Pool {
|
||||
return poolPGX.Load()
|
||||
}
|
||||
|
||||
// ClosePool closes the connection pool gracefully.
|
||||
// Should be called during application shutdown.
|
||||
func ClosePool() {
|
||||
if p := poolPGX.Load(); p != nil {
|
||||
p.Close()
|
||||
poolPGX.Store(nil)
|
||||
}
|
||||
}
|
||||
|
||||
// BeginTx begins a pgx poll transaction
|
||||
func BeginTx(ctx context.Context) (pgx.Tx, error) {
|
||||
tx, err := poolPGX.Load().Begin(ctx)
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
return nil, errors.New("failed to open db tx")
|
||||
slog.Error("failed to begin transaction", "error", err)
|
||||
return nil, fmt.Errorf("failed to open db tx: %w", err)
|
||||
}
|
||||
|
||||
return tx, err
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
// IsNotFound error check
|
||||
|
||||
82
pgm_field.go
82
pgm_field.go
@@ -1,10 +1,49 @@
|
||||
package pgm
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Field related to a table
|
||||
type Field string
|
||||
|
||||
var (
|
||||
// sqlIdentifierRegex validates SQL identifiers (alphanumeric and underscore only)
|
||||
sqlIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// validDateTruncLevels contains all allowed DATE_TRUNC precision levels
|
||||
validDateTruncLevels = map[string]bool{
|
||||
"microseconds": true,
|
||||
"milliseconds": true,
|
||||
"second": true,
|
||||
"minute": true,
|
||||
"hour": true,
|
||||
"day": true,
|
||||
"week": true,
|
||||
"month": true,
|
||||
"quarter": true,
|
||||
"year": true,
|
||||
"decade": true,
|
||||
"century": true,
|
||||
"millennium": true,
|
||||
}
|
||||
)
|
||||
|
||||
// validateSQLIdentifier checks if a string is a valid SQL identifier
|
||||
func validateSQLIdentifier(s string) error {
|
||||
if !sqlIdentifierRegex.MatchString(s) {
|
||||
return fmt.Errorf("invalid SQL identifier: %q (must be alphanumeric and underscore only)", s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// escapeSQLString escapes single quotes in a string for SQL
|
||||
func escapeSQLString(s string) string {
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
func (f Field) Name() string {
|
||||
return strings.Split(string(f), ".")[1]
|
||||
}
|
||||
@@ -18,16 +57,28 @@ func (f Field) Count() Field {
|
||||
return Field("COUNT(" + f.String() + ")")
|
||||
}
|
||||
|
||||
// ConcatWs creates a CONCAT_WS SQL function.
|
||||
// SECURITY: The sep parameter should only be a constant string, not user input.
|
||||
// Single quotes in sep will be escaped automatically.
|
||||
func ConcatWs(sep string, fields ...Field) Field {
|
||||
return Field("concat_ws('" + sep + "'," + joinFileds(fields) + ")")
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("concat_ws('" + escapedSep + "'," + joinFileds(fields) + ")")
|
||||
}
|
||||
|
||||
// StringAgg creates a STRING_AGG SQL function.
|
||||
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
|
||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||
func StringAgg(exp, sep string) Field {
|
||||
return Field("string_agg(" + exp + ",'" + sep + "')")
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("string_agg(" + exp + ",'" + escapedSep + "')")
|
||||
}
|
||||
|
||||
// StringAggCast creates a STRING_AGG SQL function with cast to varchar.
|
||||
// SECURITY: The exp parameter must be a valid field/column name, not arbitrary SQL.
|
||||
// The sep parameter should only be a constant string. Single quotes will be escaped.
|
||||
func StringAggCast(exp, sep string) Field {
|
||||
return Field("string_agg(cast(" + exp + " as varchar),'" + sep + "')")
|
||||
escapedSep := escapeSQLString(sep)
|
||||
return Field("string_agg(cast(" + exp + " as varchar),'" + escapedSep + "')")
|
||||
}
|
||||
|
||||
// StringEscape will wrap field with:
|
||||
@@ -114,6 +165,12 @@ func (f Field) RowNumberDescPartionBy(partition Field, as string) Field {
|
||||
}
|
||||
|
||||
func rowNumber(f, partition *Field, isAsc bool, as string) Field {
|
||||
// Validate as parameter is a valid SQL identifier
|
||||
if as != "" {
|
||||
if err := validateSQLIdentifier(as); err != nil {
|
||||
panic(fmt.Sprintf("invalid AS alias in rowNumber: %v", err))
|
||||
}
|
||||
}
|
||||
var orderBy string
|
||||
if isAsc {
|
||||
orderBy = " ASC"
|
||||
@@ -150,10 +207,25 @@ func (f Field) IsNotNull() Conditioner {
|
||||
// - day, week (Monday start), month, quarter, year
|
||||
// - decade, century, millennium
|
||||
func (f Field) DateTrunc(level, as string) Field {
|
||||
return Field("DATE_TRUNC('" + level + "', " + f.String() + ") AS " + as)
|
||||
// Validate level parameter against allowed values
|
||||
if !validDateTruncLevels[strings.ToLower(level)] {
|
||||
panic(fmt.Sprintf("invalid DATE_TRUNC level: %q (allowed: microseconds, milliseconds, second, minute, hour, day, week, month, quarter, year, decade, century, millennium)", level))
|
||||
}
|
||||
|
||||
// Validate as parameter is a valid SQL identifier
|
||||
if err := validateSQLIdentifier(as); err != nil {
|
||||
panic(fmt.Sprintf("invalid AS alias in DateTrunc: %v", err))
|
||||
}
|
||||
|
||||
return Field("DATE_TRUNC('" + strings.ToLower(level) + "', " + f.String() + ") AS " + as)
|
||||
}
|
||||
|
||||
func (f Field) TsRank(fieldName, as string) Field {
|
||||
// Validate as parameter is a valid SQL identifier
|
||||
if err := validateSQLIdentifier(as); err != nil {
|
||||
panic(fmt.Sprintf("invalid AS alias in TsRank: %v", err))
|
||||
}
|
||||
|
||||
return Field("TS_RANK(" + f.String() + ", " + fieldName + ") AS " + as)
|
||||
}
|
||||
|
||||
|
||||
338
pgm_field_test.go
Normal file
338
pgm_field_test.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package pgm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestValidateSQLIdentifier tests SQL identifier validation
|
||||
func TestValidateSQLIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid simple", "column_name", false},
|
||||
{"valid with numbers", "column123", false},
|
||||
{"valid underscore", "_private", false},
|
||||
{"valid mixed", "my_Column_123", false},
|
||||
{"invalid with space", "column name", true},
|
||||
{"invalid with dash", "column-name", true},
|
||||
{"invalid with dot", "table.column", true},
|
||||
{"invalid with quote", "column'name", true},
|
||||
{"invalid with semicolon", "column;DROP", true},
|
||||
{"invalid starts with number", "123column", true},
|
||||
{"invalid SQL keyword injection", "col); DROP TABLE users; --", true},
|
||||
{"empty string", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateSQLIdentifier(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateSQLIdentifier(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEscapeSQLString tests SQL string escaping
|
||||
func TestEscapeSQLString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"no quotes", "hello", "hello"},
|
||||
{"single quote", "hello'world", "hello''world"},
|
||||
{"multiple quotes", "it's a 'test'", "it''s a ''test''"},
|
||||
{"only quote", "'", "''"},
|
||||
{"empty string", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := escapeSQLString(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("escapeSQLString(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcatWsSQLInjectionPrevention tests that ConcatWs escapes quotes
|
||||
func TestConcatWsSQLInjectionPrevention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sep string
|
||||
fields []Field
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "safe separator",
|
||||
sep: ", ",
|
||||
fields: []Field{"col1", "col2"},
|
||||
contains: "concat_ws(', ',col1, col2)",
|
||||
},
|
||||
{
|
||||
name: "escaped quotes",
|
||||
sep: "', (SELECT password FROM users), '",
|
||||
fields: []Field{"col1"},
|
||||
contains: "concat_ws(''', (SELECT password FROM users), ''',col1)",
|
||||
},
|
||||
{
|
||||
name: "single quote",
|
||||
sep: "'",
|
||||
fields: []Field{"col1"},
|
||||
contains: "concat_ws('''',col1)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConcatWs(tt.sep, tt.fields...)
|
||||
if !strings.Contains(string(result), tt.contains) {
|
||||
t.Errorf("ConcatWs(%q, %v) = %q, should contain %q", tt.sep, tt.fields, result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStringAggSQLInjectionPrevention tests that StringAgg escapes quotes
|
||||
func TestStringAggSQLInjectionPrevention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exp string
|
||||
sep string
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "safe parameters",
|
||||
exp: "column_name",
|
||||
sep: ", ",
|
||||
contains: "string_agg(column_name,', ')",
|
||||
},
|
||||
{
|
||||
name: "escaped quotes in separator",
|
||||
sep: "'; DROP TABLE users; --",
|
||||
exp: "col",
|
||||
contains: "string_agg(col,'''; DROP TABLE users; --')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := StringAgg(tt.exp, tt.sep)
|
||||
if !strings.Contains(string(result), tt.contains) {
|
||||
t.Errorf("StringAgg(%q, %q) = %q, should contain %q", tt.exp, tt.sep, result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStringAggCastSQLInjectionPrevention tests that StringAggCast escapes quotes
|
||||
func TestStringAggCastSQLInjectionPrevention(t *testing.T) {
|
||||
result := StringAggCast("column", "'; DROP TABLE")
|
||||
expected := "string_agg(cast(column as varchar),'''; DROP TABLE')"
|
||||
if string(result) != expected {
|
||||
t.Errorf("StringAggCast should escape quotes, got %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDateTruncValidation tests DateTrunc level validation
|
||||
func TestDateTruncValidation(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
|
||||
validLevels := []string{
|
||||
"microseconds", "milliseconds", "second", "minute", "hour",
|
||||
"day", "week", "month", "quarter", "year", "decade", "century", "millennium",
|
||||
}
|
||||
|
||||
for _, level := range validLevels {
|
||||
t.Run("valid_"+level, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("DateTrunc(%q) should not panic for valid level", level)
|
||||
}
|
||||
}()
|
||||
result := field.DateTrunc(level, "truncated")
|
||||
if !strings.Contains(string(result), level) {
|
||||
t.Errorf("DateTrunc result should contain level %q", level)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test case-insensitive
|
||||
t.Run("case_insensitive", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("DateTrunc should accept uppercase level")
|
||||
}
|
||||
}()
|
||||
result := field.DateTrunc("MONTH", "truncated")
|
||||
if !strings.Contains(strings.ToLower(string(result)), "month") {
|
||||
t.Errorf("DateTrunc should normalize case")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDateTruncInvalidLevel tests that DateTrunc panics on invalid level
|
||||
func TestDateTruncInvalidLevel(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
|
||||
invalidLevels := []string{
|
||||
"invalid",
|
||||
"'; DROP TABLE users; --",
|
||||
"day; DELETE FROM",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, level := range invalidLevels {
|
||||
t.Run("invalid_"+level, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("DateTrunc(%q, 'alias') should panic for invalid level", level)
|
||||
}
|
||||
}()
|
||||
field.DateTrunc(level, "truncated")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDateTruncInvalidAlias tests that DateTrunc panics on invalid alias
|
||||
func TestDateTruncInvalidAlias(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
|
||||
invalidAliases := []string{
|
||||
"alias name",
|
||||
"alias-name",
|
||||
"'; DROP TABLE",
|
||||
"123alias",
|
||||
"alias.name",
|
||||
}
|
||||
|
||||
for _, alias := range invalidAliases {
|
||||
t.Run("invalid_alias_"+alias, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("DateTrunc('day', %q) should panic for invalid alias", alias)
|
||||
}
|
||||
}()
|
||||
field.DateTrunc("day", alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTsRankValidation tests TsRank alias validation
|
||||
func TestTsRankValidation(t *testing.T) {
|
||||
field := Field("search_vector")
|
||||
|
||||
validAliases := []string{"rank", "score", "relevance_123", "_rank"}
|
||||
for _, alias := range validAliases {
|
||||
t.Run("valid_"+alias, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("TsRank should not panic for valid alias %q", alias)
|
||||
}
|
||||
}()
|
||||
result := field.TsRank("query", alias)
|
||||
if !strings.Contains(string(result), alias) {
|
||||
t.Errorf("TsRank result should contain alias %q", alias)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
invalidAliases := []string{
|
||||
"'; DROP TABLE",
|
||||
"rank name",
|
||||
"rank-name",
|
||||
"123rank",
|
||||
}
|
||||
|
||||
for _, alias := range invalidAliases {
|
||||
t.Run("invalid_"+alias, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("TsRank should panic for invalid alias %q", alias)
|
||||
}
|
||||
}()
|
||||
field.TsRank("query", alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRowNumberValidation tests ROW_NUMBER alias validation
|
||||
func TestRowNumberValidation(t *testing.T) {
|
||||
field := Field("id")
|
||||
|
||||
t.Run("valid_alias", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("RowNumber should not panic for valid alias: %v", r)
|
||||
}
|
||||
}()
|
||||
result := field.RowNumber("row_num")
|
||||
if !strings.Contains(string(result), "row_num") {
|
||||
t.Errorf("RowNumber result should contain alias")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_alias", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("RowNumber should panic for invalid alias")
|
||||
}
|
||||
}()
|
||||
field.RowNumber("'; DROP TABLE")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSQLInjectionAttackVectors tests common SQL injection patterns
|
||||
func TestSQLInjectionAttackVectors(t *testing.T) {
|
||||
attacks := []string{
|
||||
"'; DROP TABLE users; --",
|
||||
"' OR '1'='1",
|
||||
"'; DELETE FROM users WHERE '1'='1",
|
||||
"1'; UPDATE users SET password='hacked'; --",
|
||||
"admin'--",
|
||||
"' UNION SELECT * FROM passwords--",
|
||||
}
|
||||
|
||||
t.Run("DateTrunc_alias_protection", func(t *testing.T) {
|
||||
field := Field("created_at")
|
||||
for _, attack := range attacks {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("DateTrunc should prevent SQL injection: %q", attack)
|
||||
}
|
||||
}()
|
||||
field.DateTrunc("day", attack)
|
||||
}()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TsRank_alias_protection", func(t *testing.T) {
|
||||
field := Field("search_vector")
|
||||
for _, attack := range attacks {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("TsRank should prevent SQL injection: %q", attack)
|
||||
}
|
||||
}()
|
||||
field.TsRank("query", attack)
|
||||
}()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcatWs_separator_escaping", func(t *testing.T) {
|
||||
for _, attack := range attacks {
|
||||
result := ConcatWs(attack, Field("col1"))
|
||||
// Check that quotes are escaped (doubled)
|
||||
if strings.Contains(attack, "'") && !strings.Contains(string(result), "''") {
|
||||
t.Errorf("ConcatWs should escape quotes in attack: %q", attack)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -75,7 +75,7 @@ func (t *Table) Update() UpdateClause {
|
||||
return qb
|
||||
}
|
||||
|
||||
// Detlete table statement
|
||||
// Delete table statement
|
||||
func (t *Table) Delete() DeleteCluase {
|
||||
qb := &deleteQry{
|
||||
debug: t.debug,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package branchuser
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package comment
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package employee
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package post
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package db
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package user
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Code generated by code.patial.tech/go/pgm/cmd DO NOT EDIT.
|
||||
// Code generated by code.patial.tech/go/pgm/cmd vdev on 2025-11-16 04:05:43 DO NOT EDIT.
|
||||
|
||||
package usersession
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package playground
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"code.patial.tech/go/pgm/playground/db"
|
||||
@@ -26,6 +28,21 @@ func TestUpdateQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateQueryValidation(t *testing.T) {
|
||||
// Test that UPDATE without Set() returns error
|
||||
err := db.User.Update().
|
||||
Where(user.Email.Eq("aa@aa.com")).
|
||||
Exec(context.Background())
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when calling Exec() without Set(), got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no columns to update") {
|
||||
t.Errorf("Expected error message to contain 'no columns to update', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkUpdateQuery-12 2004985 592.2 ns/op 1176 B/op 20 allocs/op
|
||||
func BenchmarkUpdateQuery(b *testing.B) {
|
||||
for b.Loop() {
|
||||
|
||||
9
qry.go
9
qry.go
@@ -92,13 +92,14 @@ func (c *CondGroup) Condition(args *[]any, argIdx int) string {
|
||||
defer putSB(sb)
|
||||
|
||||
sb.WriteString("(")
|
||||
currentIdx := argIdx
|
||||
for i, cond := range c.cond {
|
||||
if i == 0 {
|
||||
sb.WriteString(cond.Condition(args, argIdx+i))
|
||||
} else {
|
||||
if i > 0 {
|
||||
sb.WriteString(c.op)
|
||||
sb.WriteString(cond.Condition(args, argIdx+i))
|
||||
}
|
||||
argsLenBefore := len(*args)
|
||||
sb.WriteString(cond.Condition(args, currentIdx))
|
||||
currentIdx += len(*args) - argsLenBefore
|
||||
}
|
||||
sb.WriteString(")")
|
||||
return sb.String()
|
||||
|
||||
@@ -19,6 +19,12 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// WARNING: DELETE without WHERE clause will delete ALL rows in the table.
|
||||
// Always use Where() to specify conditions unless you intentionally want to delete all rows.
|
||||
// Example:
|
||||
// table.Delete().Where(field.Eq(value)).Exec(ctx) // Safe
|
||||
// table.Delete().Exec(ctx) // Deletes ALL rows!
|
||||
|
||||
func (q *deleteQry) Where(cond ...Conditioner) WhereOrExec {
|
||||
q.condition = append(q.condition, cond...)
|
||||
return q
|
||||
|
||||
@@ -5,6 +5,7 @@ package pgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -56,6 +57,9 @@ func (q *updateQry) Where(cond ...Conditioner) WhereOrExec {
|
||||
}
|
||||
|
||||
func (q *updateQry) Exec(ctx context.Context) error {
|
||||
if len(q.cols) == 0 {
|
||||
return errors.New("update query has no columns to update: call Set() before Exec()")
|
||||
}
|
||||
_, err := poolPGX.Load().Exec(ctx, q.String(), q.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -65,6 +69,9 @@ func (q *updateQry) Exec(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (q *updateQry) ExecTx(ctx context.Context, tx pgx.Tx) error {
|
||||
if len(q.cols) == 0 {
|
||||
return errors.New("update query has no columns to update: call Set() before ExecTx()")
|
||||
}
|
||||
_, err := tx.Exec(ctx, q.String(), q.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user