diff --git a/Makefile b/Makefile index 081d88e..b2167a0 100644 --- a/Makefile +++ b/Makefile @@ -2,3 +2,6 @@ run: go run ./cmd -o ./example/db ./example/schema.sql bench-select: go test ./example -bench BenchmarkSelect -memprofile memprofile.out -cpuprofile profile.out + +test: + go test ./playground diff --git a/playground/qry_select_test.go b/playground/qry_select_test.go index beba10c..4843b5b 100644 --- a/playground/qry_select_test.go +++ b/playground/qry_select_test.go @@ -60,7 +60,31 @@ func TestSelectWithHaving(t *testing.T) { } } -// BenchmarkSelect-12 668817 1753 ns/op 4442 B/op 59 allocs/op +func TestSelectWithJoin(t *testing.T) { + got := db.User.Select(user.Email, user.FirstName). + Join(db.UserSession, user.ID, usersession.UserID). + LeftJoin(db.BranchUser, user.ID, branchuser.UserID, pgm.Or(branchuser.RoleID.Eq("1"), branchuser.RoleID.Eq("2"))). + Where( + user.ID.Eq(3), + pgm.Or( + user.StatusID.Eq(4), + user.UpdatedAt.Eq(5), + ), + ). + Limit(10). + Offset(100). + String() + + expected := "SELECT users.email, users.first_name " + + "FROM users JOIN user_sessions ON users.id = user_sessions.user_id " + + "LEFT JOIN branch_users ON users.id = branch_users.user_id AND (branch_users.role_id = $1 OR branch_users.role_id = $2) " + + "WHERE users.id = $3 AND (users.status_id = $4 OR users.updated_at = $5) " + + "LIMIT 10 OFFSET 100" + if expected != got { + t.Errorf("\nexpected: %q\ngot: %q", expected, got) + } +} + // BenchmarkSelect-12 638901 1860 ns/op 4266 B/op 61 allocs/op func BenchmarkSelect(b *testing.B) { for b.Loop() { diff --git a/qry.go b/qry.go index e396dec..32a5337 100644 --- a/qry.go +++ b/qry.go @@ -21,10 +21,10 @@ type ( SelectClause interface { // Join and Inner Join are same - Join(m Table, t1Field, t2Field Field) SelectClause - LeftJoin(m Table, t1Field, t2Field Field) SelectClause - RightJoin(m Table, t1Field, t2Field Field) SelectClause - FullJoin(m Table, t1Field, t2Field Field) SelectClause + Join(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause + LeftJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause + RightJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause + FullJoin(m Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause CrossJoin(m Table) SelectClause WhereClause OrderByClause diff --git a/qry_select.go b/qry_select.go index e73b6ca..f71843a 100644 --- a/qry_select.go +++ b/qry_select.go @@ -62,23 +62,46 @@ func (t Table) Select(field ...Field) SelectClause { return qb } -func (q *selectQry) Join(t Table, t1Field, t2Field Field) SelectClause { - q.join = append(q.join, "JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String()) - return q +func (q *selectQry) Join(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { + return q.buildJoin(t, "JOIN", t1Field, t2Field, cond...) } -func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field) SelectClause { - q.join = append(q.join, "LEFT JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String()) - return q +func (q *selectQry) LeftJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { + return q.buildJoin(t, "LEFT JOIN", t1Field, t2Field, cond...) } -func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field) SelectClause { - q.join = append(q.join, "RIGHT JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String()) - return q +func (q *selectQry) RightJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { + return q.buildJoin(t, "RIGHT JOIN", t1Field, t2Field, cond...) } -func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field) SelectClause { - q.join = append(q.join, "FULL JOIN "+t.Name+" ON "+t1Field.String()+" = "+t2Field.String()) +func (q *selectQry) FullJoin(t Table, t1Field, t2Field Field, cond ...Conditioner) SelectClause { + return q.buildJoin(t, "FULL JOIN", t1Field, t2Field, cond...) +} + +func (q *selectQry) buildJoin(t Table, joinKW string, t1Field, t2Field Field, cond ...Conditioner) SelectClause { + str := joinKW + " " + t.Name + " ON " + t1Field.String() + " = " + t2Field.String() + if len(cond) == 0 { // Join with no condition + q.join = append(q.join, str) + return q + } + + // Join has condition(s) + sb := getSB() + defer putSB(sb) + sb.Grow(len(str) * 2) + + sb.WriteString(str + " AND ") + + var argIdx int + for i, c := range cond { + argIdx = len(q.args) + if i > 0 { + sb.WriteString(" AND ") + } + sb.WriteString(c.Condition(&q.args, argIdx)) + } + + q.join = append(q.join, sb.String()) return q }