From f1601020b15f66bfeae39554801b5d9767a917a4 Mon Sep 17 00:00:00 2001 From: Ankit Patial Date: Fri, 20 Feb 2026 16:38:24 +0530 Subject: [PATCH] cluade code review changes --- crypto/ed25519.go | 4 ++-- date/date.go | 2 +- dotenv/parser.go | 2 +- dotenv/write.go | 9 ++------ email/gomail/send.go | 1 + email/transport_dump.go | 6 +++++- email/transport_smtp.go | 2 +- gz/gz.go | 12 ++++++++++- jwt/jwt.go | 34 +++++++++++++++-------------- open/open_windows.go | 2 +- request/pager.go | 18 +++++++++++----- request/payload.go | 48 ++++++++++++++++++++++------------------- request/query.go | 13 ++++++++++- response/reply.go | 13 ++++++++--- structs/structs.go | 18 +++++++++++----- uid/sqid.go | 36 ++++++++++++++++++++++++++++--- 16 files changed, 150 insertions(+), 70 deletions(-) diff --git a/crypto/ed25519.go b/crypto/ed25519.go index 5fc159a..ad3778c 100644 --- a/crypto/ed25519.go +++ b/crypto/ed25519.go @@ -65,7 +65,7 @@ func ParseEdPrivateKey(d []byte) (ed25519.PrivateKey, error) { case ed25519.PrivateKey: return pub, nil default: - return nil, errors.New("key type is not RSA") + return nil, errors.New("key type is not Ed25519") } } @@ -99,6 +99,6 @@ func ParseEdPublicKey(d []byte) (ed25519.PublicKey, error) { case ed25519.PublicKey: return pub, nil default: - return nil, errors.New("key type is not RSA") + return nil, errors.New("key type is not Ed25519") } } diff --git a/date/date.go b/date/date.go index 99f0a63..ca91a5f 100644 --- a/date/date.go +++ b/date/date.go @@ -39,7 +39,7 @@ func StartOfDay(date time.Time) time.Time { func EndOfDay(date time.Time) time.Time { year, month, day := date.Date() - return time.Date(year, month, day, 23, 59, 59, 0, date.Location()) + return time.Date(year, month, day, 23, 59, 59, 999999999, date.Location()) } func StartOfMonth(date time.Time) time.Time { diff --git a/dotenv/parser.go b/dotenv/parser.go index 5c932af..16177c5 100644 --- a/dotenv/parser.go +++ b/dotenv/parser.go @@ -255,7 +255,7 @@ func isLineEnd(r rune) bool { var ( escapeRegex = regexp.MustCompile(`\\.`) - expandVarRegex = regexp.MustCompile(`(\\)?(\$)(\()?\{?([A-Z0-9_]+)?\}?`) + expandVarRegex = regexp.MustCompile(`(\\)?(\$)(\()?\{?([A-Za-z0-9_]+)?\}?`) unescapeCharsRegex = regexp.MustCompile(`\\([^$])`) ) diff --git a/dotenv/write.go b/dotenv/write.go index 6afb650..9264a25 100644 --- a/dotenv/write.go +++ b/dotenv/write.go @@ -4,7 +4,6 @@ import ( "fmt" "os" "sort" - "strconv" "strings" ) @@ -20,7 +19,7 @@ func Write(envMap map[string]string, filename string) error { if err != nil { return err } - file, err := os.Create(filename) + file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err } @@ -37,11 +36,7 @@ func Write(envMap map[string]string, filename string) error { func marshal(envMap map[string]string) (string, error) { lines := make([]string, 0, len(envMap)) for k, v := range envMap { - if d, err := strconv.Atoi(v); err == nil { - lines = append(lines, fmt.Sprintf(`%s=%d`, k, d)) - } else { - lines = append(lines, fmt.Sprintf(`%s="%s"`, k, doubleQuoteEscape(v))) - } + lines = append(lines, fmt.Sprintf(`%s="%s"`, k, doubleQuoteEscape(v))) } sort.Strings(lines) return strings.Join(lines, "\n"), nil diff --git a/email/gomail/send.go b/email/gomail/send.go index 58958b5..7384202 100644 --- a/email/gomail/send.go +++ b/email/gomail/send.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/mail" + "slices" ) // Sender is the interface that wraps the Send method. diff --git a/email/transport_dump.go b/email/transport_dump.go index c32d041..8a9240e 100644 --- a/email/transport_dump.go +++ b/email/transport_dump.go @@ -1,6 +1,7 @@ package email import ( + "fmt" "os" "path/filepath" @@ -19,7 +20,10 @@ func (DumpToTemp) Send(msg *Message) error { } dir := os.TempDir() - id, _ := uuid.NewV7() + id, err := uuid.NewV7() + if err != nil { + return fmt.Errorf("email: failed to generate UUID: %w", err) + } file := filepath.Join(dir, id.String()+".html") if err := os.WriteFile(file, []byte(msg.HtmlBody), 0440); err != nil { diff --git a/email/transport_smtp.go b/email/transport_smtp.go index 7beca77..9f2f3f5 100644 --- a/email/transport_smtp.go +++ b/email/transport_smtp.go @@ -87,6 +87,6 @@ func (t SMTP) Send(msg *Message) error { return err } - slog.Info("sent email %s" + msg.Subject) + slog.Info("sent email", slog.String("subject", msg.Subject)) return nil } diff --git a/gz/gz.go b/gz/gz.go index bd40cd6..38f4f5b 100644 --- a/gz/gz.go +++ b/gz/gz.go @@ -8,8 +8,14 @@ package gz import ( "bytes" "compress/gzip" + "errors" + "io" ) +// MaxDecompressedSize is the maximum allowed size for decompressed data (256MB). +// This prevents decompression bomb attacks. Override if you need larger outputs. +var MaxDecompressedSize int64 = 256 << 20 + func Zip(data []byte) ([]byte, error) { var b bytes.Buffer gz := gzip.NewWriter(&b) @@ -37,9 +43,13 @@ func UnZip(data []byte) ([]byte, error) { defer r.Close() // Ensure reader is closed to prevent resource leak var resB bytes.Buffer - if _, err := resB.ReadFrom(r); err != nil { + if _, err := io.Copy(&resB, io.LimitReader(r, MaxDecompressedSize+1)); err != nil { return nil, err } + if int64(resB.Len()) > MaxDecompressedSize { + return nil, errors.New("gz: decompressed data exceeds maximum allowed size") + } + return resB.Bytes(), nil } diff --git a/jwt/jwt.go b/jwt/jwt.go index 829a33c..3dd7e37 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -28,13 +28,14 @@ func Parse(key ed25519.PrivateKey, tokenString string, issuer string) (jwt.MapCl // SignEdDSA (Edwards-curve Digital Signature Algorithm, typically Ed25519) is an excellent, // modern choice for JWT signing—arguably safer and more efficient than both HS256 and traditional RSA/ECDSA. func SignEdDSA(key ed25519.PrivateKey, claims map[string]any, issuer string, d time.Duration) (string, error) { - cl := jwt.MapClaims{ - "iss": issuer, - "iat": jwt.NewNumericDate(time.Now().UTC()), - "exp": jwt.NewNumericDate(time.Now().Add(d)), - } + cl := jwt.MapClaims{} maps.Copy(cl, claims) + // Set standard claims after user claims to prevent override + cl["iss"] = issuer + cl["iat"] = jwt.NewNumericDate(time.Now().UTC()) + cl["exp"] = jwt.NewNumericDate(time.Now().Add(d)) + t := jwt.NewWithClaims(jwt.SigningMethodEdDSA, cl) return t.SignedString(key) } @@ -61,13 +62,14 @@ func ParseEdDSA(key ed25519.PrivateKey, tokenString string, issuer string) (jwt. } func SignHS256(secret []byte, claims map[string]any, issuer string, d time.Duration) (string, error) { - cl := jwt.MapClaims{ - "iss": issuer, - "iat": jwt.NewNumericDate(time.Now().UTC()), - "exp": jwt.NewNumericDate(time.Now().Add(d)), - } + cl := jwt.MapClaims{} maps.Copy(cl, claims) + // Set standard claims after user claims to prevent override + cl["iss"] = issuer + cl["iat"] = jwt.NewNumericDate(time.Now().UTC()) + cl["exp"] = jwt.NewNumericDate(time.Now().Add(d)) + t := jwt.NewWithClaims(jwt.SigningMethodHS256, cl) return t.SignedString(secret) } @@ -102,10 +104,12 @@ func ParseHS256(secret []byte, tokenString string, issuer string) (jwt.MapClaims func SignES256( key *ecdsa.PrivateKey, issuer, audience, subject string, d time.Duration, claims map[string]any, ) (string, error) { - cl := jwt.MapClaims{ - "iat": jwt.NewNumericDate(time.Now().UTC()), - "exp": jwt.NewNumericDate(time.Now().Add(d)), - } + cl := jwt.MapClaims{} + maps.Copy(cl, claims) + + // Set standard claims after user claims to prevent override + cl["iat"] = jwt.NewNumericDate(time.Now().UTC()) + cl["exp"] = jwt.NewNumericDate(time.Now().Add(d)) if issuer != "" { cl["iss"] = issuer @@ -119,8 +123,6 @@ func SignES256( cl["sub"] = subject } - maps.Copy(cl, claims) - t := jwt.NewWithClaims(jwt.SigningMethodES256, cl) return t.SignedString(key) } diff --git a/open/open_windows.go b/open/open_windows.go index 9513009..76c9a12 100644 --- a/open/open_windows.go +++ b/open/open_windows.go @@ -18,7 +18,7 @@ func cleaninput(input string) string { } func open(input string) *exec.Cmd { - cmd := exec.Command(runDll32, cmd, input) + cmd := exec.Command(runDll32, cmd, cleaninput(input)) // cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} return cmd } diff --git a/request/pager.go b/request/pager.go index b7a749e..685e67b 100644 --- a/request/pager.go +++ b/request/pager.go @@ -13,6 +13,9 @@ import ( "code.patial.tech/go/appcore/ptr" ) +// MaxPageSize is the maximum allowed page size to prevent resource exhaustion. +var MaxPageSize = 1000 + type Pager struct { OrderBy *string `json:"orderBy"` OrderAsc *bool `json:"orderAsc"` @@ -39,7 +42,9 @@ func (i *Pager) Offset() int { return 0 } - return (i.Page - 1) * i.Limit() + page := max(i.Page, 1) + + return (page - 1) * i.Limit() } func (i *Pager) Limit() int { @@ -63,14 +68,17 @@ func GetPager(r *http.Request) Pager { } if v := r.URL.Query().Get("pg"); v != "" { - if vv, err := strconv.Atoi(v); err == nil { - p.Page = int(vv) + if vv, err := strconv.Atoi(v); err == nil && vv > 0 { + p.Page = vv } } if v := r.URL.Query().Get("pg_s"); v != "" { - if vv, err := strconv.Atoi(v); err == nil { - p.Size = int(vv) + if vv, err := strconv.Atoi(v); err == nil && vv > 0 { + if vv > MaxPageSize { + vv = MaxPageSize + } + p.Size = vv } } diff --git a/request/payload.go b/request/payload.go index 7b07aed..5226504 100644 --- a/request/payload.go +++ b/request/payload.go @@ -15,10 +15,10 @@ import ( "code.patial.tech/go/appcore/validate" ) -// MaxRequestBodySize is the maximum allowed size for request bodies (1MB default). +// MaxRequestBodySize is the maximum allowed size for request bodies (5MB default). // This prevents resource exhaustion attacks from large payloads. // Override this value if you need to accept larger requests. -var MaxRequestBodySize int64 = 1 << 20 // 1MB +var MaxRequestBodySize int64 = 5 << 20 // 5MB func FormField(r *http.Request, key string) (string, error) { f, err := Payload[map[string]any](r) @@ -49,22 +49,24 @@ func PayloadWithValidate[T any](r *http.Request) (T, error) { func Payload[T any](r *http.Request) (T, error) { var p T + if r.ContentLength > MaxRequestBodySize { + return p, errors.New("request body too large") + } + // Limit request body size to prevent resource exhaustion - limited := io.LimitReader(r.Body, MaxRequestBodySize) + limited := io.LimitReader(r.Body, MaxRequestBodySize+1) decoder := json.NewDecoder(limited) if err := decoder.Decode(&p); err != nil { - // Check if we hit the size limit - if err == io.EOF || err == io.ErrUnexpectedEOF { - // Try to read one more byte to see if there's more data - var buf [1]byte - if n, _ := limited.Read(buf[:]); n == 0 { - // We hit the limit - return p, errors.New("request body too large") - } - } return p, err } + + // Check if there's more data beyond the limit + var buf [1]byte + if n, _ := limited.Read(buf[:]); n > 0 { + return p, errors.New("request body too large") + } + return p, nil } @@ -72,21 +74,23 @@ func Payload[T any](r *http.Request) (T, error) { // This is useful when you want to decode into an existing variable. // The request body size is limited by MaxRequestBodySize to prevent DoS attacks. func DecodeJSON(r *http.Request, v any) error { + if r.ContentLength > MaxRequestBodySize { + return errors.New("request body too large") + } + // Limit request body size to prevent resource exhaustion - limited := io.LimitReader(r.Body, MaxRequestBodySize) + limited := io.LimitReader(r.Body, MaxRequestBodySize+1) decoder := json.NewDecoder(limited) if err := decoder.Decode(v); err != nil { - // Check if we hit the size limit - if err == io.EOF || err == io.ErrUnexpectedEOF { - // Try to read one more byte to see if there's more data - var buf [1]byte - if n, _ := limited.Read(buf[:]); n == 0 { - // We hit the limit - return errors.New("request body too large") - } - } return err } + + // Check if there's more data beyond the limit + var buf [1]byte + if n, _ := limited.Read(buf[:]); n > 0 { + return errors.New("request body too large") + } + return nil } diff --git a/request/query.go b/request/query.go index 73d5cd3..9553e95 100644 --- a/request/query.go +++ b/request/query.go @@ -35,10 +35,21 @@ func NumberParam[T Number](r *http.Request, key string) (T, error) { return T(n), nil } - // int param + // unsigned int param + if k >= reflect.Uint && k <= reflect.Uint64 { + n, err := strconv.ParseUint(p, 10, 64) + if err != nil { + return noop, fmt.Errorf("query param: %q is not a valid unsigned integer", key) + } + + return T(n), nil + } + + // signed int param n, err := strconv.ParseInt(p, 10, 64) if err != nil { return noop, fmt.Errorf("query param: %q is not a valid integer", key) } + return T(n), nil } diff --git a/response/reply.go b/response/reply.go index 8b0873c..ccd4264 100644 --- a/response/reply.go +++ b/response/reply.go @@ -59,10 +59,12 @@ func reply(w http.ResponseWriter, data any, p *request.Pager) { // json data... w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(Detail{ + if err := json.NewEncoder(w).Encode(Detail{ Data: data, Pager: p, - }) + }); err != nil { + slog.Error(err.Error()) + } } func BadRequest(w http.ResponseWriter, err error) { @@ -92,9 +94,14 @@ func SessionExpired(w http.ResponseWriter) { } } +// Deprecated: Use NotAuthorized instead. func NotAutorized(w http.ResponseWriter) { + NotAuthorized(w) +} + +func NotAuthorized(w http.ResponseWriter) { w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(http.StatusForbidden) _, writeErr := fmt.Fprint(w, "{\"error\": \"You are not authorized to perform this action\"}") if writeErr != nil { slog.Error(writeErr.Error()) diff --git a/structs/structs.go b/structs/structs.go index c36903e..e47e5c3 100644 --- a/structs/structs.go +++ b/structs/structs.go @@ -20,13 +20,21 @@ func Map(obj any) map[string]any { for i := range val.NumField() { fieldName := typ.Field(i).Name - fieldValueKind := val.Field(i).Kind() + field := val.Field(i) + fieldValueKind := field.Kind() var fieldValue any - if fieldValueKind == reflect.Struct { - fieldValue = Map(val.Field(i).Interface()) - } else { - fieldValue = val.Field(i).Interface() + switch fieldValueKind { + case reflect.Struct: + fieldValue = Map(field.Interface()) + case reflect.Pointer: + if !field.IsNil() && field.Elem().Kind() == reflect.Struct { + fieldValue = Map(field.Elem().Interface()) + } else { + fieldValue = field.Interface() + } + default: + fieldValue = field.Interface() } result[fieldName] = fieldValue diff --git a/uid/sqid.go b/uid/sqid.go index dc1a63e..a3e4165 100644 --- a/uid/sqid.go +++ b/uid/sqid.go @@ -5,15 +5,45 @@ package uid -import "github.com/sqids/sqids-go" +import ( + "sync" + + "github.com/sqids/sqids-go" +) type Service interface { SquiOptions() sqids.Options } +var ( + mu sync.Mutex + sqidInst *sqids.Sqids + sqidOpts *sqids.Options +) + +func getSqids(svc Service) (*sqids.Sqids, error) { + opts := svc.SquiOptions() + + mu.Lock() + defer mu.Unlock() + + if sqidInst != nil && sqidOpts != nil && opts.Alphabet == sqidOpts.Alphabet && opts.MinLength == sqidOpts.MinLength { + return sqidInst, nil + } + + s, err := sqids.New(opts) + if err != nil { + return nil, err + } + + sqidInst = s + sqidOpts = &opts + return s, nil +} + // Encode a slice of IDs into one unique ID func Encode(svc Service, ids ...uint64) (string, error) { - s, err := sqids.New(svc.SquiOptions()) + s, err := getSqids(svc) if err != nil { return "", err } @@ -23,7 +53,7 @@ func Encode(svc Service, ids ...uint64) (string, error) { // Decode an ID back to slice of IDs func Decode(svc Service, id string) ([]uint64, error) { - s, err := sqids.New(svc.SquiOptions()) + s, err := getSqids(svc) if err != nil { return nil, err }