Allow to inject custom validator in VLESS controller (#3453)

* Make Validator an interface

* Move validator creation away from VLESS inbound controller
This commit is contained in:
Torikki 2024-09-13 17:51:26 +03:00 committed by GitHub
parent 3a8c5f38e8
commit c259e4e4a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 33 additions and 23 deletions

View file

@ -64,7 +64,7 @@ func EncodeRequestHeader(writer io.Writer, request *protocol.RequestHeader, requ
}
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator *vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
buffer := buf.StackNew()
defer buffer.Release()

View file

@ -42,7 +42,7 @@ func TestRequestSerialization(t *testing.T) {
buffer := buf.StackNew()
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
Validator := new(vless.Validator)
Validator := new(vless.MemoryValidator)
Validator.Add(user)
actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
@ -83,7 +83,7 @@ func TestInvalidRequest(t *testing.T) {
buffer := buf.StackNew()
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
Validator := new(vless.Validator)
Validator := new(vless.MemoryValidator)
Validator.Add(user)
_, _, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
@ -114,7 +114,7 @@ func TestMuxRequest(t *testing.T) {
buffer := buf.StackNew()
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
Validator := new(vless.Validator)
Validator := new(vless.MemoryValidator)
Validator.Add(user)
actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)

View file

@ -45,7 +45,21 @@ func init() {
}); err != nil {
return nil, err
}
return New(ctx, config.(*Config), dc)
c := config.(*Config)
validator := new(vless.MemoryValidator)
for _, user := range c.Clients {
u, err := user.ToMemoryUser()
if err != nil {
return nil, errors.New("failed to get VLESS user").Base(err).AtError()
}
if err := validator.Add(u); err != nil {
return nil, errors.New("failed to initiate user").Base(err).AtError()
}
}
return New(ctx, c, dc, validator)
}))
}
@ -53,30 +67,20 @@ func init() {
type Handler struct {
inboundHandlerManager feature_inbound.Manager
policyManager policy.Manager
validator *vless.Validator
validator vless.Validator
dns dns.Client
fallbacks map[string]map[string]map[string]*Fallback // or nil
// regexps map[string]*regexp.Regexp // or nil
}
// New creates a new VLess inbound handler.
func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) {
func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) {
v := core.MustFromContext(ctx)
handler := &Handler{
inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
validator: new(vless.Validator),
dns: dc,
}
for _, user := range config.Clients {
u, err := user.ToMemoryUser()
if err != nil {
return nil, errors.New("failed to get VLESS user").Base(err).AtError()
}
if err := handler.AddUser(ctx, u); err != nil {
return nil, errors.New("failed to initiate user").Base(err).AtError()
}
validator: validator,
}
if config.Fallbacks != nil {

View file

@ -9,15 +9,21 @@ import (
"github.com/xtls/xray-core/common/uuid"
)
// Validator stores valid VLESS users.
type Validator struct {
type Validator interface {
Get(id uuid.UUID) *protocol.MemoryUser
Add(u *protocol.MemoryUser) error
Del(email string) error
}
// MemoryValidator stores valid VLESS users.
type MemoryValidator struct {
// Considering email's usage here, map + sync.Mutex/RWMutex may have better performance.
email sync.Map
users sync.Map
}
// Add a VLESS user, Email must be empty or unique.
func (v *Validator) Add(u *protocol.MemoryUser) error {
func (v *MemoryValidator) Add(u *protocol.MemoryUser) error {
if u.Email != "" {
_, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u)
if loaded {
@ -29,7 +35,7 @@ func (v *Validator) Add(u *protocol.MemoryUser) error {
}
// Del a VLESS user with a non-empty Email.
func (v *Validator) Del(e string) error {
func (v *MemoryValidator) Del(e string) error {
if e == "" {
return errors.New("Email must not be empty.")
}
@ -44,7 +50,7 @@ func (v *Validator) Del(e string) error {
}
// Get a VLESS user with UUID, nil if user doesn't exist.
func (v *Validator) Get(id uuid.UUID) *protocol.MemoryUser {
func (v *MemoryValidator) Get(id uuid.UUID) *protocol.MemoryUser {
u, _ := v.users.Load(id)
if u != nil {
return u.(*protocol.MemoryUser)