mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-14 20:53:18 +00:00
Allow to inject custom validator in VLESS controller (#3453)
* Make Validator an interface * Move validator creation away from VLESS inbound controller
This commit is contained in:
parent
3a8c5f38e8
commit
c259e4e4a6
|
@ -64,7 +64,7 @@ func EncodeRequestHeader(writer io.Writer, request *protocol.RequestHeader, requ
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
|
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
|
||||||
func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator *vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
|
func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
|
||||||
buffer := buf.StackNew()
|
buffer := buf.StackNew()
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,7 @@ func TestRequestSerialization(t *testing.T) {
|
||||||
buffer := buf.StackNew()
|
buffer := buf.StackNew()
|
||||||
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
|
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
|
||||||
|
|
||||||
Validator := new(vless.Validator)
|
Validator := new(vless.MemoryValidator)
|
||||||
Validator.Add(user)
|
Validator.Add(user)
|
||||||
|
|
||||||
actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
|
actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
|
||||||
|
@ -83,7 +83,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||||
buffer := buf.StackNew()
|
buffer := buf.StackNew()
|
||||||
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
|
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
|
||||||
|
|
||||||
Validator := new(vless.Validator)
|
Validator := new(vless.MemoryValidator)
|
||||||
Validator.Add(user)
|
Validator.Add(user)
|
||||||
|
|
||||||
_, _, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
|
_, _, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
|
||||||
|
@ -114,7 +114,7 @@ func TestMuxRequest(t *testing.T) {
|
||||||
buffer := buf.StackNew()
|
buffer := buf.StackNew()
|
||||||
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
|
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))
|
||||||
|
|
||||||
Validator := new(vless.Validator)
|
Validator := new(vless.MemoryValidator)
|
||||||
Validator.Add(user)
|
Validator.Add(user)
|
||||||
|
|
||||||
actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
|
actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
|
||||||
|
|
|
@ -45,7 +45,21 @@ func init() {
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return New(ctx, config.(*Config), dc)
|
|
||||||
|
c := config.(*Config)
|
||||||
|
|
||||||
|
validator := new(vless.MemoryValidator)
|
||||||
|
for _, user := range c.Clients {
|
||||||
|
u, err := user.ToMemoryUser()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("failed to get VLESS user").Base(err).AtError()
|
||||||
|
}
|
||||||
|
if err := validator.Add(u); err != nil {
|
||||||
|
return nil, errors.New("failed to initiate user").Base(err).AtError()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return New(ctx, c, dc, validator)
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,30 +67,20 @@ func init() {
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
inboundHandlerManager feature_inbound.Manager
|
inboundHandlerManager feature_inbound.Manager
|
||||||
policyManager policy.Manager
|
policyManager policy.Manager
|
||||||
validator *vless.Validator
|
validator vless.Validator
|
||||||
dns dns.Client
|
dns dns.Client
|
||||||
fallbacks map[string]map[string]map[string]*Fallback // or nil
|
fallbacks map[string]map[string]map[string]*Fallback // or nil
|
||||||
// regexps map[string]*regexp.Regexp // or nil
|
// regexps map[string]*regexp.Regexp // or nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new VLess inbound handler.
|
// New creates a new VLess inbound handler.
|
||||||
func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) {
|
func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) {
|
||||||
v := core.MustFromContext(ctx)
|
v := core.MustFromContext(ctx)
|
||||||
handler := &Handler{
|
handler := &Handler{
|
||||||
inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
|
inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
|
||||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||||
validator: new(vless.Validator),
|
|
||||||
dns: dc,
|
dns: dc,
|
||||||
}
|
validator: validator,
|
||||||
|
|
||||||
for _, user := range config.Clients {
|
|
||||||
u, err := user.ToMemoryUser()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("failed to get VLESS user").Base(err).AtError()
|
|
||||||
}
|
|
||||||
if err := handler.AddUser(ctx, u); err != nil {
|
|
||||||
return nil, errors.New("failed to initiate user").Base(err).AtError()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Fallbacks != nil {
|
if config.Fallbacks != nil {
|
||||||
|
|
|
@ -9,15 +9,21 @@ import (
|
||||||
"github.com/xtls/xray-core/common/uuid"
|
"github.com/xtls/xray-core/common/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Validator stores valid VLESS users.
|
type Validator interface {
|
||||||
type Validator struct {
|
Get(id uuid.UUID) *protocol.MemoryUser
|
||||||
|
Add(u *protocol.MemoryUser) error
|
||||||
|
Del(email string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryValidator stores valid VLESS users.
|
||||||
|
type MemoryValidator struct {
|
||||||
// Considering email's usage here, map + sync.Mutex/RWMutex may have better performance.
|
// Considering email's usage here, map + sync.Mutex/RWMutex may have better performance.
|
||||||
email sync.Map
|
email sync.Map
|
||||||
users sync.Map
|
users sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a VLESS user, Email must be empty or unique.
|
// Add a VLESS user, Email must be empty or unique.
|
||||||
func (v *Validator) Add(u *protocol.MemoryUser) error {
|
func (v *MemoryValidator) Add(u *protocol.MemoryUser) error {
|
||||||
if u.Email != "" {
|
if u.Email != "" {
|
||||||
_, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u)
|
_, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u)
|
||||||
if loaded {
|
if loaded {
|
||||||
|
@ -29,7 +35,7 @@ func (v *Validator) Add(u *protocol.MemoryUser) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Del a VLESS user with a non-empty Email.
|
// Del a VLESS user with a non-empty Email.
|
||||||
func (v *Validator) Del(e string) error {
|
func (v *MemoryValidator) Del(e string) error {
|
||||||
if e == "" {
|
if e == "" {
|
||||||
return errors.New("Email must not be empty.")
|
return errors.New("Email must not be empty.")
|
||||||
}
|
}
|
||||||
|
@ -44,7 +50,7 @@ func (v *Validator) Del(e string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a VLESS user with UUID, nil if user doesn't exist.
|
// Get a VLESS user with UUID, nil if user doesn't exist.
|
||||||
func (v *Validator) Get(id uuid.UUID) *protocol.MemoryUser {
|
func (v *MemoryValidator) Get(id uuid.UUID) *protocol.MemoryUser {
|
||||||
u, _ := v.users.Load(id)
|
u, _ := v.users.Load(id)
|
||||||
if u != nil {
|
if u != nil {
|
||||||
return u.(*protocol.MemoryUser)
|
return u.(*protocol.MemoryUser)
|
||||||
|
|
Loading…
Reference in a new issue