package middleware

import (
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)

// ipLimiter is the per-client bookkeeping entry stored by IPRateLimiter.
type ipLimiter struct {
	// limiter is the token bucket applied to this client IP.
	limiter *rate.Limiter
	// lastSeen is the time of the most recent lookup for this IP; the
	// periodic sweep uses it to evict IPs that have stopped sending requests.
	lastSeen time.Time
}

// IPRateLimiter hands out one token-bucket limiter per client IP. Entries are
// created on first use and evicted after they have been idle for a while.
// Construct it with NewIPRateLimiter; the zero value is not usable because
// the ips map would be nil.
type IPRateLimiter struct {
	// ips maps a client IP to its limiter entry; guarded by mu.
	ips map[string]*ipLimiter
	// mu guards ips and lastCleanup.
	// NOTE(review): the methods in this file only take the exclusive Lock,
	// so a plain sync.Mutex would be sufficient.
	mu sync.RWMutex
	// rateLimit is the steady-state number of requests per second per IP.
	rateLimit rate.Limit
	// burst is the number of requests an IP may make in a single burst.
	burst int
	// lastCleanup is when idle entries were last swept; sweeps are spaced
	// out over time rather than run on every request.
	lastCleanup time.Time
}

// NewIPRateLimiter builds a limiter that allows each client IP r requests
// per second with bursts of up to b requests. The idle-entry sweep clock
// starts at construction time.
func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter {
	l := new(IPRateLimiter)
	l.ips = make(map[string]*ipLimiter)
	l.rateLimit = r
	l.burst = b
	l.lastCleanup = time.Now()
	return l
}

// GetLimiter returns the token-bucket limiter for ip, creating one with the
// configured rate and burst the first time the IP is seen. Every call
// refreshes the entry's lastSeen time. If more than five minutes have passed
// since the previous sweep, the call also evicts entries that have been idle
// for over an hour (see cleanup).
//
// The whole operation runs under the exclusive lock, so it is safe to call
// from multiple goroutines.
func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
	i.mu.Lock()
	defer i.mu.Unlock()

	// Read the clock once, after acquiring the lock, so the sweep check,
	// lastCleanup and lastSeen all agree on one instant and never move
	// backwards because of lock contention.
	now := time.Now()

	// Amortize eviction of idle IPs into normal traffic instead of running a
	// background goroutine.
	if now.Sub(i.lastCleanup) > 5*time.Minute {
		i.cleanup()
		i.lastCleanup = now
	}

	if entry, ok := i.ips[ip]; ok {
		entry.lastSeen = now
		return entry.limiter
	}

	entry := &ipLimiter{
		limiter:  rate.NewLimiter(i.rateLimit, i.burst),
		lastSeen: now,
	}
	i.ips[ip] = entry
	return entry.limiter
}

// cleanup evicts every IP whose last lookup happened more than one hour ago.
// It does not lock i.mu itself; GetLimiter calls it while holding the write
// lock.
func (i *IPRateLimiter) cleanup() {
	const idleTTL = time.Hour
	cutoff := time.Now().Add(-idleTTL)
	// Deleting from a map while ranging over it is permitted in Go.
	for key, e := range i.ips {
		if e.lastSeen.Before(cutoff) {
			delete(i.ips, key)
		}
	}
}

// Allow reports whether a request from ip may proceed right now. When it
// returns true, one token has been taken from that IP's bucket.
func (i *IPRateLimiter) Allow(ip string) bool {
	lim := i.GetLimiter(ip)
	return lim.Allow()
}

// getClientIP returns the client address used as the rate-limit key.
//
// Sources are consulted in order:
//   - the Cloudflare CF-Connecting-IP header,
//   - the first valid entry of X-Forwarded-For,
//   - X-Real-IP,
//   - the host part of r.RemoteAddr.
//
// Header values are accepted only when they parse as an IP address. They are
// returned in canonical form, e.g. lower-case IPv6 and IPv4-mapped IPv6
// collapsed to dotted IPv4. A header value that is not an IP falls through to
// the next source. Without this, one client could occupy arbitrarily many
// limiter buckets by sending junk strings or re-spelling its address.
//
// SECURITY NOTE(review): these headers are client-controlled unless a trusted
// proxy in front of this service overwrites them. If the service is reachable
// directly, a client can claim any IP and bypass per-IP limits. Confirm the
// deployment strips or overwrites these headers.
func getClientIP(r *http.Request) string {
	// canonical returns the normalized form of s if it is a literal IP
	// address, or "" otherwise.
	canonical := func(s string) string {
		if ip := net.ParseIP(strings.TrimSpace(s)); ip != nil {
			return ip.String()
		}
		return ""
	}

	if ip := canonical(r.Header.Get("CF-Connecting-IP")); ip != "" {
		return ip
	}

	// X-Forwarded-For is conventionally "client, proxy1, proxy2"; the
	// left-most valid entry is the claimed originating client.
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		for _, part := range strings.Split(forwarded, ",") {
			if ip := canonical(part); ip != "" {
				return ip
			}
		}
	}

	if ip := canonical(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	// RemoteAddr is normally "host:port" as recorded by the HTTP server.
	remote := strings.TrimSpace(r.RemoteAddr)
	if host, _, err := net.SplitHostPort(remote); err == nil {
		if ip := canonical(host); ip != "" {
			return ip
		}
		// e.g. a zoned IPv6 address such as "fe80::1%eth0", which ParseIP
		// rejects; keep the host as-is rather than losing the key.
		return host
	}
	if ip := canonical(remote); ip != "" {
		return ip
	}
	return r.RemoteAddr
}

// formatFloat renders f as a non-negative decimal integer string, truncating
// any fractional part toward zero (e.g. 2.9 -> "2").
//
// Out-of-range values are clamped rather than passed to a float-to-int
// conversion, because the Go spec leaves that result implementation-defined
// when the value does not fit:
//   - negative numbers, -0 and NaN yield "0";
//   - values >= 2^63, including +Inf, yield "9223372036854775807".
//
// int64 is used instead of int so the result does not depend on the
// platform word size.
func formatFloat(f float64) string {
	const maxInt64 = 1<<63 - 1
	switch {
	case !(f > 0):
		// Written as !(f > 0) rather than f <= 0 so that NaN also lands here.
		return "0"
	case f >= 1<<63:
		return fmt.Sprintf("%d", int64(maxInt64))
	default:
		return fmt.Sprintf("%d", int64(f))
	}
}
