feature: verify tokens
This commit is contained in:
@@ -4,15 +4,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitserver.in/patialtech/rano/config"
|
||||
"gitserver.in/patialtech/rano/db"
|
||||
"gitserver.in/patialtech/rano/db/ent/user"
|
||||
"gitserver.in/patialtech/rano/mailer"
|
||||
"gitserver.in/patialtech/rano/mailer/message"
|
||||
"gitserver.in/patialtech/rano/util/crypto"
|
||||
"gitserver.in/patialtech/rano/util/logger"
|
||||
"gitserver.in/patialtech/rano/util/validate"
|
||||
)
|
||||
|
||||
@@ -36,17 +37,17 @@ var (
|
||||
//
|
||||
// will return created userID on success
|
||||
func Create(ctx context.Context, inp *CreateInput) (int64, error) {
|
||||
// check for nil inp
|
||||
// Check for nil input.
|
||||
if inp == nil {
|
||||
return 0, ErrCreateInpNil
|
||||
}
|
||||
|
||||
// validate
|
||||
// Validate struct.
|
||||
if err := validate.Struct(inp); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// compare pwd and comparePwd
|
||||
// Compare pwd and comparePwd.
|
||||
if inp.Pwd != inp.ConfirmPwd {
|
||||
return 0, ErrWrongConfirmPwd
|
||||
}
|
||||
@@ -56,42 +57,73 @@ func Create(ctx context.Context, inp *CreateInput) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// save record to DB
|
||||
client := db.Client()
|
||||
u, err := client.User.Create().
|
||||
// Begin a transaction.
|
||||
tx, err := db.Client().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Save User to DB.
|
||||
u, err := tx.User.Create().
|
||||
SetEmail(inp.Email).
|
||||
SetPwdHash(h).
|
||||
SetPwdSalt(salt).
|
||||
SetFirstName(inp.FirstName).
|
||||
SetMiddleName(inp.MiddleName).
|
||||
SetLastName(inp.LastName).
|
||||
SetStatus(user.StatusActive).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
logger.Error(err, slog.String("ref", "user: create error"))
|
||||
return 0, errors.New("failed to create user")
|
||||
}
|
||||
|
||||
// email
|
||||
// Get a new email-verification token
|
||||
tokenDuration := time.Hour * 6
|
||||
token, err := newTokenToVerifyEmail(u.ID, tokenDuration)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Save token to DB
|
||||
err = tx.VerifyToken.Create().
|
||||
SetToken(token).
|
||||
SetExpiresAt(time.Now().Add(tokenDuration).UTC()).
|
||||
SetPurpose("VerifyEmail").
|
||||
SetUserID(u.ID).Exec(ctx)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
name := fullName(inp.FirstName, inp.MiddleName, inp.LastName)
|
||||
// Send a welcome email with a link to verigy email-address.
|
||||
err = mailer.Send(
|
||||
[]mail.Address{
|
||||
{Name: inp.FullName(), Address: inp.Email},
|
||||
{Name: name, Address: inp.Email},
|
||||
},
|
||||
&message.Welcome{
|
||||
Name: inp.FullName(),
|
||||
Name: name,
|
||||
VerifyURL: config.VerifyEmailURL(token),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error(err, slog.String("ref", "user: send welcome email"))
|
||||
_ = tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// ALL Done!
|
||||
// Created a new user in system.
|
||||
return u.ID, nil
|
||||
}
|
||||
|
||||
func (inp *CreateInput) FullName() string {
|
||||
if inp == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("%s %s %s", inp.FirstName, inp.MiddleName, inp.LastName)
|
||||
func fullName(fName, mName, lName string) string {
|
||||
name := fmt.Sprintf("%s %s %s", fName, mName, lName)
|
||||
return strings.Join(strings.Fields(name), " ")
|
||||
}
|
||||
|
@@ -3,6 +3,8 @@ package user
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v7"
|
||||
)
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
@@ -22,12 +24,12 @@ func TestCreate(t *testing.T) {
|
||||
|
||||
t.Run("create", func(t *testing.T) {
|
||||
if _, err := Create(context.Background(), &CreateInput{
|
||||
Email: "aa@aa.com",
|
||||
Email: gofakeit.Email(),
|
||||
Pwd: "pwd123",
|
||||
ConfirmPwd: "pwd123",
|
||||
FirstName: "Ankit",
|
||||
MiddleName: "Singh",
|
||||
LastName: "Patial",
|
||||
FirstName: gofakeit.FirstName(),
|
||||
MiddleName: gofakeit.MiddleName(),
|
||||
LastName: gofakeit.LastName(),
|
||||
RoleID: 1,
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
|
13
pkg/user/password.go
Normal file
13
pkg/user/password.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package user
|
||||
|
||||
// EmailResetPWD link to user to reset password
|
||||
func EmailResetPWD(email string) {
|
||||
// send Password reset instructionss
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// UpdatePWD in database
|
||||
func UpdatePWD(token, email, pwd, confirmPWD string) error {
|
||||
// update pwd in DB
|
||||
panic("not implemented")
|
||||
}
|
163
pkg/user/session.go
Normal file
163
pkg/user/session.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"gitserver.in/patialtech/rano/config"
|
||||
"gitserver.in/patialtech/rano/db"
|
||||
"gitserver.in/patialtech/rano/db/ent"
|
||||
"gitserver.in/patialtech/rano/db/ent/user"
|
||||
"gitserver.in/patialtech/rano/graph/model"
|
||||
"gitserver.in/patialtech/rano/util/crypto"
|
||||
"gitserver.in/patialtech/rano/util/logger"
|
||||
)
|
||||
|
||||
type (
|
||||
SessionUser struct {
|
||||
ID string
|
||||
Email string
|
||||
Name string
|
||||
RoleID int
|
||||
}
|
||||
AuthUser = model.AuthUser
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCred = errors.New("invalid email or password")
|
||||
ErrAccountNotActive = errors.New("account is not active")
|
||||
ErrAccountLocked = errors.New("account is locked, please try after sometime")
|
||||
ErrUnexpected = errors.New("unexpected error has happened")
|
||||
)
|
||||
|
||||
func CtxWithUser(ctx context.Context, u *AuthUser) context.Context {
|
||||
return context.WithValue(ctx, config.AuthUserCtxKey, &SessionUser{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Name: u.Name,
|
||||
RoleID: u.RoleID,
|
||||
})
|
||||
}
|
||||
|
||||
func CtxUser(ctx context.Context) *SessionUser {
|
||||
u, _ := ctx.Value(config.AuthUserCtxKey).(*SessionUser)
|
||||
return u
|
||||
}
|
||||
|
||||
// NewSession for user.
|
||||
//
|
||||
// Authenticated
|
||||
func NewSession(ctx context.Context, email, pwd string) (*AuthUser, error) {
|
||||
// authenticate.
|
||||
u, err := authenticate(ctx, email, pwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 30 day token life
|
||||
until := time.Now().Add(time.Hour * 24 * 30).UTC()
|
||||
|
||||
// create sesion entry in db
|
||||
db.Client().UserSession.Create().
|
||||
SetUserID(u.ID).
|
||||
SetIssuedAt(time.Now().UTC()).
|
||||
SetExpiresAt(until).
|
||||
SetIP("").
|
||||
SetUserAgent("")
|
||||
|
||||
return &AuthUser{
|
||||
Name: fullName(u.FirstName, *u.MiddleName, u.LastName),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RemoveSession entry from DB
|
||||
func RemoveSession(sID uint) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// authenticate user against DB
|
||||
func authenticate(ctx context.Context, email, pwd string) (*ent.User, error) {
|
||||
client := db.Client()
|
||||
|
||||
// incident email attr
|
||||
attrEmail := slog.String("email", email)
|
||||
|
||||
// get user by given email
|
||||
u, err := client.User.
|
||||
Query().
|
||||
Where(user.EmailEQ(email)).
|
||||
Select(
|
||||
user.FieldEmail, user.FieldPwdHash, user.FieldPwdSalt,
|
||||
user.FieldLoginFailedCount, user.FieldLoginLockedUntil, user.FieldLoginAttemptOn,
|
||||
user.FieldFirstName, user.FieldMiddleName, user.FieldLastName,
|
||||
user.FieldStatus,
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
logger.Incident(ctx, "Authenticate", "wrong email", attrEmail)
|
||||
return nil, ErrInvalidCred
|
||||
}
|
||||
|
||||
logger.Error(err)
|
||||
return nil, ErrUnexpected
|
||||
}
|
||||
|
||||
// check account is ready for authentication
|
||||
// ensure that user account is active or perform other needed checks
|
||||
if u.Status != user.StatusActive {
|
||||
logger.Incident(ctx, "Authenticate", "account issue", attrEmail)
|
||||
return nil, ErrAccountNotActive
|
||||
}
|
||||
|
||||
// check account is locked
|
||||
lck := u.LoginLockedUntil
|
||||
now := time.Now().UTC()
|
||||
if lck != nil && now.Before(lck.UTC()) {
|
||||
logger.Incident(ctx, "Authenticate", "account locked", attrEmail)
|
||||
return nil, ErrAccountLocked
|
||||
}
|
||||
|
||||
upQry := client.User.UpdateOneID(u.ID)
|
||||
// compare password
|
||||
// in-case password is wrong, lets increment failed attempt
|
||||
if !crypto.ComparePasswordHash(pwd, u.PwdHash, u.PwdSalt) {
|
||||
var locked bool
|
||||
u.LoginFailedCount++
|
||||
upQry.
|
||||
SetLoginAttemptOn(time.Now().UTC()).
|
||||
SetLoginFailedCount(u.LoginFailedCount)
|
||||
|
||||
// lock user if count is more that 4
|
||||
if u.LoginFailedCount > 4 {
|
||||
locked = true
|
||||
upQry.SetLoginLockedUntil(time.Now().Add(time.Hour * 6).UTC())
|
||||
}
|
||||
|
||||
// update user login attempt status
|
||||
if err = upQry.Exec(ctx); err != nil {
|
||||
return nil, ErrUnexpected
|
||||
}
|
||||
|
||||
if locked {
|
||||
return nil, ErrAccountLocked
|
||||
}
|
||||
|
||||
return nil, ErrInvalidCred
|
||||
}
|
||||
|
||||
if u.LoginFailedCount > 0 {
|
||||
u.LoginFailedCount = 0
|
||||
upQry.ClearLoginFailedCount()
|
||||
if err := upQry.Exec(ctx); err != nil {
|
||||
logger.Error(err, attrEmail)
|
||||
}
|
||||
}
|
||||
|
||||
// let's not get them out
|
||||
u.PwdHash = ""
|
||||
u.PwdSalt = ""
|
||||
return u, nil
|
||||
}
|
50
pkg/user/token.go
Normal file
50
pkg/user/token.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gitserver.in/patialtech/rano/util/uid"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrExpiredToken = errors.New("expired token")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
)
|
||||
|
||||
// newTokenToVerifyEmail for a user for given duration
|
||||
func newTokenToVerifyEmail(userID int64, d time.Duration) (string, error) {
|
||||
expiresAt := time.Now().Add(d).UTC().UnixMilli()
|
||||
return uid.Encode([]uint64{
|
||||
uint64(userID),
|
||||
1, // identifies that its token to verify email
|
||||
uint64(expiresAt),
|
||||
})
|
||||
}
|
||||
|
||||
// tokenToVerifyEmail will check for valid email token that is yet not expired
|
||||
//
|
||||
// returns userID on success
|
||||
func tokenToVerifyEmail(token string) (int64, error) {
|
||||
ids, err := uid.Decode(token)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// slice must have 3 entries
|
||||
if len(ids) != 3 {
|
||||
return 0, ErrInvalidToken
|
||||
}
|
||||
|
||||
// must be an email verify token
|
||||
if ids[1] != 1 {
|
||||
return 0, ErrInvalidToken
|
||||
}
|
||||
|
||||
// check expiry
|
||||
if int64(ids[2]) < time.Now().UTC().UnixMilli() {
|
||||
return 0, ErrExpiredToken
|
||||
}
|
||||
|
||||
return int64(ids[0]), nil
|
||||
}
|
43
pkg/user/token_test.go
Normal file
43
pkg/user/token_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func FuzzNewVerifyEmailToken(f *testing.F) {
|
||||
f.Add(int64(123))
|
||||
f.Fuzz(func(t *testing.T, userID int64) {
|
||||
_, err := newTokenToVerifyEmail(userID, time.Millisecond*100)
|
||||
if err != nil {
|
||||
t.Errorf("failed for input %d, %v", userID, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmailToken(t *testing.T) {
|
||||
uID := int64(1234)
|
||||
// create a token
|
||||
t1, err := newTokenToVerifyEmail(uID, time.Millisecond*100)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// let decode token
|
||||
id, err := tokenToVerifyEmail(t1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
} else if uID != id {
|
||||
t.Error("uid and decoded id are not same, ", uID, "!=", id)
|
||||
}
|
||||
|
||||
// lets wait and try decode again
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
_, err = tokenToVerifyEmail(t1)
|
||||
if !errors.Is(err, ErrExpiredToken) {
|
||||
t.Error("expected expired token error")
|
||||
}
|
||||
}
|
60
pkg/user/verify.go
Normal file
60
pkg/user/verify.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
|
||||
"gitserver.in/patialtech/rano/db"
|
||||
"gitserver.in/patialtech/rano/db/ent"
|
||||
"gitserver.in/patialtech/rano/db/ent/verifytoken"
|
||||
"gitserver.in/patialtech/rano/util/logger"
|
||||
)
|
||||
|
||||
// VerifyEmailAddress by a token
|
||||
func VerifyEmailAddress(ctx context.Context, token string) error {
|
||||
// decode token
|
||||
uid, err := tokenToVerifyEmail(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := db.Client()
|
||||
|
||||
// get token from DB
|
||||
vt, err := client.VerifyToken.Query().Where(verifytoken.TokenEQ(token)).Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
logger.Error(err, slog.String("ref", "pkg/user/verify.VerifyEmail"))
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
// all good, lets do the following
|
||||
// 1. Update user email verify status
|
||||
// 2. Remvoe token from DB
|
||||
// do it in a transaction
|
||||
|
||||
tx, err := client.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return errors.New("unexpected error")
|
||||
}
|
||||
|
||||
// update user email verify status
|
||||
if err = tx.User.UpdateOneID(uid).SetEmailVerified(true).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove token from DB
|
||||
if err = tx.VerifyToken.DeleteOneID(vt.ID).Exec(ctx); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// we are all done now,
|
||||
// let's commit
|
||||
return tx.Commit()
|
||||
}
|
Reference in New Issue
Block a user