Allow to inject custom validator in VLESS controller (#3453)

* Make Validator an interface

* Move validator creation away from VLESS inbound controller
This commit is contained in:
Torikki 2024-09-13 17:51:26 +03:00 committed by GitHub
parent 3a8c5f38e8
commit c259e4e4a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 33 additions and 23 deletions

View file

@ -64,7 +64,7 @@ func EncodeRequestHeader(writer io.Writer, request *protocol.RequestHeader, requ
} }
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream. // DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator *vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) { func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
buffer := buf.StackNew() buffer := buf.StackNew()
defer buffer.Release() defer buffer.Release()

View file

@ -42,7 +42,7 @@ func TestRequestSerialization(t *testing.T) {
buffer := buf.StackNew() buffer := buf.StackNew()
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons)) common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
Validator := new(vless.Validator) Validator := new(vless.MemoryValidator)
Validator.Add(user) Validator.Add(user)
actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator) actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
@ -83,7 +83,7 @@ func TestInvalidRequest(t *testing.T) {
buffer := buf.StackNew() buffer := buf.StackNew()
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons)) common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
Validator := new(vless.Validator) Validator := new(vless.MemoryValidator)
Validator.Add(user) Validator.Add(user)
_, _, _, err := DecodeRequestHeader(false, nil, &buffer, Validator) _, _, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
@ -114,7 +114,7 @@ func TestMuxRequest(t *testing.T) {
buffer := buf.StackNew() buffer := buf.StackNew()
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons)) common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
Validator := new(vless.Validator) Validator := new(vless.MemoryValidator)
Validator.Add(user) Validator.Add(user)
actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator) actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)

View file

@ -45,7 +45,21 @@ func init() {
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
return New(ctx, config.(*Config), dc)
c := config.(*Config)
validator := new(vless.MemoryValidator)
for _, user := range c.Clients {
u, err := user.ToMemoryUser()
if err != nil {
return nil, errors.New("failed to get VLESS user").Base(err).AtError()
}
if err := validator.Add(u); err != nil {
return nil, errors.New("failed to initiate user").Base(err).AtError()
}
}
return New(ctx, c, dc, validator)
})) }))
} }
@ -53,30 +67,20 @@ func init() {
type Handler struct { type Handler struct {
inboundHandlerManager feature_inbound.Manager inboundHandlerManager feature_inbound.Manager
policyManager policy.Manager policyManager policy.Manager
validator *vless.Validator validator vless.Validator
dns dns.Client dns dns.Client
fallbacks map[string]map[string]map[string]*Fallback // or nil fallbacks map[string]map[string]map[string]*Fallback // or nil
// regexps map[string]*regexp.Regexp // or nil // regexps map[string]*regexp.Regexp // or nil
} }
// New creates a new VLess inbound handler. // New creates a new VLess inbound handler.
func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) { func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) {
v := core.MustFromContext(ctx) v := core.MustFromContext(ctx)
handler := &Handler{ handler := &Handler{
inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager), inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
validator: new(vless.Validator),
dns: dc, dns: dc,
} validator: validator,
for _, user := range config.Clients {
u, err := user.ToMemoryUser()
if err != nil {
return nil, errors.New("failed to get VLESS user").Base(err).AtError()
}
if err := handler.AddUser(ctx, u); err != nil {
return nil, errors.New("failed to initiate user").Base(err).AtError()
}
} }
if config.Fallbacks != nil { if config.Fallbacks != nil {

View file

@ -9,15 +9,21 @@ import (
"github.com/xtls/xray-core/common/uuid" "github.com/xtls/xray-core/common/uuid"
) )
// Validator stores valid VLESS users. type Validator interface {
type Validator struct { Get(id uuid.UUID) *protocol.MemoryUser
Add(u *protocol.MemoryUser) error
Del(email string) error
}
// MemoryValidator stores valid VLESS users.
type MemoryValidator struct {
// Considering email's usage here, map + sync.Mutex/RWMutex may have better performance. // Considering email's usage here, map + sync.Mutex/RWMutex may have better performance.
email sync.Map email sync.Map
users sync.Map users sync.Map
} }
// Add a VLESS user, Email must be empty or unique. // Add a VLESS user, Email must be empty or unique.
func (v *Validator) Add(u *protocol.MemoryUser) error { func (v *MemoryValidator) Add(u *protocol.MemoryUser) error {
if u.Email != "" { if u.Email != "" {
_, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u) _, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u)
if loaded { if loaded {
@ -29,7 +35,7 @@ func (v *Validator) Add(u *protocol.MemoryUser) error {
} }
// Del a VLESS user with a non-empty Email. // Del a VLESS user with a non-empty Email.
func (v *Validator) Del(e string) error { func (v *MemoryValidator) Del(e string) error {
if e == "" { if e == "" {
return errors.New("Email must not be empty.") return errors.New("Email must not be empty.")
} }
@ -44,7 +50,7 @@ func (v *Validator) Del(e string) error {
} }
// Get a VLESS user with UUID, nil if user doesn't exist. // Get a VLESS user with UUID, nil if user doesn't exist.
func (v *Validator) Get(id uuid.UUID) *protocol.MemoryUser { func (v *MemoryValidator) Get(id uuid.UUID) *protocol.MemoryUser {
u, _ := v.users.Load(id) u, _ := v.users.Load(id)
if u != nil { if u != nil {
return u.(*protocol.MemoryUser) return u.(*protocol.MemoryUser)