mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-22 04:39:19 +02:00
Fix concurrent map writes error in ohm.Select(). (#2943)
* Add unit test for ohm.tagsCache. * Fix concurrent map writes in ohm.Select(). --------- Co-authored-by: nobody <nobody@nowhere.mars>
This commit is contained in:
parent
10255bca83
commit
d20a835016
|
@ -2,9 +2,14 @@ package outbound_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/app/policy"
|
||||
"github.com/xtls/xray-core/app/proxyman"
|
||||
. "github.com/xtls/xray-core/app/proxyman/outbound"
|
||||
"github.com/xtls/xray-core/app/stats"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
|
@ -78,3 +83,91 @@ func TestOutboundWithStatCounter(t *testing.T) {
|
|||
t.Errorf("Expected conn to be CounterConnection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagsCache(t *testing.T) {
|
||||
|
||||
test_duration := 10 * time.Second
|
||||
threads_num := 50
|
||||
delay := 10 * time.Millisecond
|
||||
tags_prefix := "node"
|
||||
|
||||
tags := sync.Map{}
|
||||
counter := atomic.Uint64{}
|
||||
|
||||
ohm, err := New(context.Background(), &proxyman.OutboundConfig{})
|
||||
if err != nil {
|
||||
t.Error("failed to create outbound handler manager")
|
||||
}
|
||||
config := &core.Config{
|
||||
App: []*serial.TypedMessage{},
|
||||
}
|
||||
v, _ := core.New(config)
|
||||
v.AddFeature(ohm)
|
||||
ctx := context.WithValue(context.Background(), xrayKey, v)
|
||||
|
||||
stop_add_rm := false
|
||||
wg_add_rm := sync.WaitGroup{}
|
||||
addHandlers := func() {
|
||||
defer wg_add_rm.Done()
|
||||
for !stop_add_rm {
|
||||
time.Sleep(delay)
|
||||
idx := counter.Add(1)
|
||||
tag := fmt.Sprintf("%s%d", tags_prefix, idx)
|
||||
cfg := &core.OutboundHandlerConfig{
|
||||
Tag: tag,
|
||||
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
|
||||
}
|
||||
if h, err := NewHandler(ctx, cfg); err == nil {
|
||||
if err := ohm.AddHandler(ctx, h); err == nil {
|
||||
// t.Log("add handler:", tag)
|
||||
tags.Store(tag, nil)
|
||||
} else {
|
||||
t.Error("failed to add handler:", tag)
|
||||
}
|
||||
} else {
|
||||
t.Error("failed to create handler:", tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rmHandlers := func() {
|
||||
defer wg_add_rm.Done()
|
||||
for !stop_add_rm {
|
||||
time.Sleep(delay)
|
||||
tags.Range(func(key interface{}, value interface{}) bool {
|
||||
if _, ok := tags.LoadAndDelete(key); ok {
|
||||
// t.Log("remove handler:", key)
|
||||
ohm.RemoveHandler(ctx, key.(string))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
selectors := []string{tags_prefix}
|
||||
wg_get := sync.WaitGroup{}
|
||||
stop_get := false
|
||||
getTags := func() {
|
||||
defer wg_get.Done()
|
||||
for !stop_get {
|
||||
time.Sleep(delay)
|
||||
_ = ohm.Select(selectors)
|
||||
// t.Logf("get tags: %v", tag)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < threads_num; i++ {
|
||||
wg_add_rm.Add(2)
|
||||
go rmHandlers()
|
||||
go addHandlers()
|
||||
wg_get.Add(1)
|
||||
go getTags()
|
||||
}
|
||||
|
||||
time.Sleep(test_duration)
|
||||
stop_add_rm = true
|
||||
wg_add_rm.Wait()
|
||||
stop_get = true
|
||||
wg_get.Wait()
|
||||
}
|
||||
|
|
|
@ -22,14 +22,14 @@ type Manager struct {
|
|||
taggedHandler map[string]outbound.Handler
|
||||
untaggedHandlers []outbound.Handler
|
||||
running bool
|
||||
tagsCache map[string][]string
|
||||
tagsCache *sync.Map
|
||||
}
|
||||
|
||||
// New creates a new Manager.
|
||||
func New(ctx context.Context, config *proxyman.OutboundConfig) (*Manager, error) {
|
||||
m := &Manager{
|
||||
taggedHandler: make(map[string]outbound.Handler),
|
||||
tagsCache: make(map[string][]string),
|
||||
tagsCache: &sync.Map{},
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
@ -106,7 +106,7 @@ func (m *Manager) AddHandler(ctx context.Context, handler outbound.Handler) erro
|
|||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
|
||||
m.tagsCache = make(map[string][]string)
|
||||
m.tagsCache = &sync.Map{}
|
||||
|
||||
if m.defaultHandler == nil {
|
||||
m.defaultHandler = handler
|
||||
|
@ -137,7 +137,7 @@ func (m *Manager) RemoveHandler(ctx context.Context, tag string) error {
|
|||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
|
||||
m.tagsCache = make(map[string][]string)
|
||||
m.tagsCache = &sync.Map{}
|
||||
|
||||
delete(m.taggedHandler, tag)
|
||||
if m.defaultHandler != nil && m.defaultHandler.Tag() == tag {
|
||||
|
@ -149,14 +149,15 @@ func (m *Manager) RemoveHandler(ctx context.Context, tag string) error {
|
|||
|
||||
// Select implements outbound.HandlerSelector.
|
||||
func (m *Manager) Select(selectors []string) []string {
|
||||
m.access.RLock()
|
||||
defer m.access.RUnlock()
|
||||
|
||||
key := strings.Join(selectors, ",")
|
||||
if cache, ok := m.tagsCache[key]; ok {
|
||||
return cache
|
||||
if cache, ok := m.tagsCache.Load(key); ok {
|
||||
return cache.([]string)
|
||||
}
|
||||
|
||||
m.access.RLock()
|
||||
defer m.access.RUnlock()
|
||||
|
||||
tags := make([]string, 0, len(selectors))
|
||||
|
||||
for tag := range m.taggedHandler {
|
||||
|
@ -169,7 +170,7 @@ func (m *Manager) Select(selectors []string) []string {
|
|||
}
|
||||
|
||||
sort.Strings(tags)
|
||||
m.tagsCache[key] = tags
|
||||
m.tagsCache.Store(key, tags)
|
||||
|
||||
return tags
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue