Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #14

Merged
merged 2 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ go get github.com/vearne/ratelimit
```
### Usage
#### 1. create redis.Client
with "github.com/go-redis/redis"
with "github.com/redis/go-redis"
Supports both redis master-slave mode and cluster mode
```
client := redis.NewClient(&redis.Options{
Expand Down Expand Up @@ -132,15 +132,17 @@ package main
import (
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/redis/go-redis/v9"
"github.com/vearne/ratelimit"
"github.com/vearne/ratelimit/counter"
"github.com/vearne/ratelimit/tokenbucket"
slog "github.com/vearne/simplelog"
"sync"
"time"
)

func consume(r ratelimit.Limiter, group *sync.WaitGroup,
c *ratelimit.Counter, targetCount int) {
c *counter.Counter, targetCount int) {
defer group.Done()
var ok bool
for {
Expand Down Expand Up @@ -168,7 +170,7 @@ func main() {
DB: 0, // use default DB
})

limiter, err := ratelimit.NewTokenBucketRateLimiter(
limiter, err := tokenbucket.NewTokenBucketRateLimiter(
context.Background(),
client,
"key:token",
Expand All @@ -184,7 +186,7 @@ func main() {

var wg sync.WaitGroup
total := 50
counter := ratelimit.NewCounter()
counter := counter.NewCounter()
start := time.Now()
for i := 0; i < 10; i++ {
wg.Add(1)
Expand All @@ -197,7 +199,7 @@ func main() {
```

### Dependency
[go-redis/redis](https://github.com/go-redis/redis)
[redis/go-redis](https://github.com/redis/go-redis)

### Thanks
The development of the module was inspired by the Reference 1.
Expand Down
14 changes: 8 additions & 6 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ go get github.com/vearne/ratelimit
```
## 用法
### 1. 创建 redis.Client
依赖 "github.com/go-redis/redis"
依赖 "github.com/redis/go-redis"
同时支持redis 主从模式和cluster模式
```
client := redis.NewClient(&redis.Options{
Expand Down Expand Up @@ -136,15 +136,17 @@ package main
import (
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/redis/go-redis/v9"
"github.com/vearne/ratelimit"
"github.com/vearne/ratelimit/counter"
"github.com/vearne/ratelimit/tokenbucket"
slog "github.com/vearne/simplelog"
"sync"
"time"
)

func consume(r ratelimit.Limiter, group *sync.WaitGroup,
c *ratelimit.Counter, targetCount int) {
c *counter.Counter, targetCount int) {
defer group.Done()
var ok bool
for {
Expand Down Expand Up @@ -172,7 +174,7 @@ func main() {
DB: 0, // use default DB
})

limiter, err := ratelimit.NewTokenBucketRateLimiter(
limiter, err := tokenbucket.NewTokenBucketRateLimiter(
context.Background(),
client,
"key:token",
Expand All @@ -188,7 +190,7 @@ func main() {

var wg sync.WaitGroup
total := 50
counter := ratelimit.NewCounter()
counter := counter.NewCounter()
start := time.Now()
for i := 0; i < 10; i++ {
wg.Add(1)
Expand All @@ -200,7 +202,7 @@ func main() {
}
```
### 依赖
[go-redis/redis](https://github.com/go-redis/redis)
[redis/go-redis](https://github.com/redis/go-redis)

### 致谢
模块的开发受到了资料1的启发,在此表示感谢
Expand Down
28 changes: 12 additions & 16 deletions alg.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ return increment
key ->
token_count -> {token_count}
updateTime -> {lastUpdateTime}* 1000000 + {microsecond}
*/
*/

const TokenBucketScript = `
local bucket = KEYS[1]
Expand Down Expand Up @@ -85,14 +85,12 @@ end
return count
`


/*
key Type: string

// updateTime
key -> {lastUpdateTime}* 1000000 + {microsecond}
key Type: string

*/
// updateTime
key -> {lastUpdateTime}* 1000000 + {microsecond}
*/
const LeakyBucketScript = `
local bucket = KEYS[1]
local interval = tonumber(ARGV[1])
Expand All @@ -117,15 +115,13 @@ end
return count
`


var (
algMap map[int]string
AlgMap map[int]string
)


func init(){
algMap = make(map[int]string)
algMap[CounterAlg] = counterScript
algMap[TokenBucketAlg] = TokenBucketScript
algMap[LeakyBucketAlg] = LeakyBucketScript
}
func init() {
AlgMap = make(map[int]string)
AlgMap[CounterAlg] = counterScript
AlgMap[TokenBucketAlg] = TokenBucketScript
AlgMap[LeakyBucketAlg] = LeakyBucketScript
}
2 changes: 1 addition & 1 deletion counter.go → counter/counter.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ratelimit
package counter

import "sync"

Expand Down
49 changes: 33 additions & 16 deletions counter_limiter.go → counter/counter_limiter.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package ratelimit
package counter

import (
"context"
"crypto/sha1"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/redis/go-redis/v9"
"github.com/vearne/ratelimit"
slog "github.com/vearne/simplelog"
"golang.org/x/sync/singleflight"
"golang.org/x/time/rate"
Expand All @@ -14,7 +15,7 @@ import (

//nolint:govet
type CounterLimiter struct {
BaseRateLimiter
ratelimit.BaseRateLimiter
duration time.Duration
throughput int
batchSize int
Expand All @@ -29,9 +30,11 @@ type CounterLimiter struct {
antiDDoSLimiter *rate.Limiter
}

type Option func(*CounterLimiter)

func NewCounterRateLimiter(ctx context.Context, client redis.Cmdable, key string, duration time.Duration,
throughput int,
batchSize int) (Limiter, error) {
batchSize int, opts ...Option) (ratelimit.Limiter, error) {

_, err := client.Ping(ctx).Result()
if err != nil {
Expand All @@ -50,23 +53,35 @@ func NewCounterRateLimiter(ctx context.Context, client redis.Cmdable, key string
return nil, errors.New("batchSize must greater than 0")
}

script := algMap[CounterAlg]
script := ratelimit.AlgMap[ratelimit.CounterAlg]
scriptSHA1 := fmt.Sprintf("%x", sha1.Sum([]byte(script)))

r := CounterLimiter{
BaseRateLimiter: BaseRateLimiter{redisClient: client, scriptSHA1: scriptSHA1, key: key},
BaseRateLimiter: ratelimit.BaseRateLimiter{RedisClient: client, ScriptSHA1: scriptSHA1, Key: key},
duration: duration,
throughput: throughput,
batchSize: batchSize,
N: 0,
AntiDDoS: true,
}
r.interval = duration / time.Duration(throughput)
r.Interval = duration / time.Duration(throughput)

if !r.redisClient.ScriptExists(ctx, r.scriptSHA1).Val()[0] {
r.redisClient.ScriptLoad(ctx, script).Val()
// Loop through each option
for _, opt := range opts {
// Call the option giving the instantiated
opt(&r)
}

values, err := r.RedisClient.ScriptExists(ctx, r.ScriptSHA1).Result()
if err != nil {
return nil, err
}
if !values[0] {
_, err = r.RedisClient.ScriptLoad(ctx, script).Result()
if err != nil {
return nil, err
}
}
// 2x throughput
throughputPerSec := int(float64(throughput) / float64(duration/time.Second))
r.antiDDoSLimiter = rate.NewLimiter(rate.Limit(throughputPerSec*2), throughputPerSec*2)
Expand All @@ -75,8 +90,10 @@ func NewCounterRateLimiter(ctx context.Context, client redis.Cmdable, key string
}

// just for test
func (r *CounterLimiter) WithAntiDDos(antiDDoS bool) {
r.AntiDDoS = antiDDoS
func WithAntiDDos(antiDDoS bool) Option {
return func(r *CounterLimiter) {
r.AntiDDoS = antiDDoS
}
}

func (r *CounterLimiter) tryTakeFromLocal() bool {
Expand All @@ -101,7 +118,7 @@ func (r *CounterLimiter) Wait(ctx context.Context) (err error) {
}

deadline, ok := ctx.Deadline()
minWaitTime := r.interval
minWaitTime := r.Interval

slog.Debug("minWaitTime:%v", minWaitTime)
if ok {
Expand Down Expand Up @@ -143,11 +160,11 @@ func (r *CounterLimiter) Take(ctx context.Context) (bool, error) {
}

// 2. try to get from redis
_, err, _ := r.g.Do(r.key, func() (interface{}, error) {
x, err := r.redisClient.EvalSha(
_, err, _ := r.g.Do(r.Key, func() (interface{}, error) {
x, err := r.RedisClient.EvalSha(
ctx,
r.scriptSHA1,
[]string{r.key},
r.ScriptSHA1,
[]string{r.Key},
int(r.duration/time.Microsecond),
r.throughput,
r.batchSize,
Expand Down
107 changes: 107 additions & 0 deletions counter/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package counter

import (
"context"
"fmt"
"github.com/go-redis/redismock/v9"
"github.com/stretchr/testify/assert"
"log"
"testing"
"time"
)

const (
key = "key:count"
hashVal = "bdbede5669d5e48d6e6c2967aeed2f72f03868ac"
)

func MyMatch(expected, actual []interface{}) error {
expectedStr := fmt.Sprintf("%v", expected)
actualStr := fmt.Sprintf("%v", actual)
if expectedStr == actualStr {
return nil
}
log.Printf("expectedStr:%v, actualStr:%v", expectedStr, actualStr)
return fmt.Errorf("not equal, expectedStr:%s, actualStr:%s", expectedStr, actualStr)
}

func TestTakeFail(t *testing.T) {
db, mock := redismock.NewClientMock()

mock = mock.CustomMatch(MyMatch)
mock.ExpectPing().SetVal("PONG")

mock.ExpectScriptExists(hashVal).SetVal([]bool{true})
mock.ExpectEvalSha(hashVal, []string{key}, 1000000, 3, 2).SetVal(int64(0))

limiter, err := NewCounterRateLimiter(context.Background(), db, key, time.Second,
3,
2,
WithAntiDDos(false))
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}

ok, err := limiter.Take(context.Background())
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}
if !ok {
assert.Equal(t, ok, false)
}
}

func TestTakeSuccess(t *testing.T) {
db, mock := redismock.NewClientMock()

mock = mock.CustomMatch(MyMatch)
mock.ExpectPing().SetVal("PONG")

mock.ExpectScriptExists(hashVal).SetVal([]bool{true})
mock.ExpectEvalSha(hashVal, []string{key}, 1000000, 3, 2).SetVal(int64(1))

limiter, err := NewCounterRateLimiter(context.Background(), db, key, time.Second,
3,
2,
WithAntiDDos(false))
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}

ok, err := limiter.Take(context.Background())
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}
if !ok {
assert.Equal(t, ok, true)
}
}

func TestContextTimeOut(t *testing.T) {
db, mock := redismock.NewClientMock()
mock = mock.CustomMatch(MyMatch)
mock.ExpectPing().SetVal("PONG")

mock.ExpectScriptExists(hashVal).SetVal([]bool{true, true})
for i := 0; i < 1000; i++ {
mock.ExpectEvalSha(hashVal, []string{key}, 1000000, 3, 2).SetVal(int64(0))
}

limiter, err := NewCounterRateLimiter(context.Background(), db, key, time.Second,
3,
2,
WithAntiDDos(false))
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}

waitCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err = limiter.Wait(waitCtx)
assert.Contains(t, err.Error(), "timeout")
}
Loading
Loading