Refactor: *net.UDPAddr -> *net.Destination

https://t.me/projectXray/111998
This commit is contained in:
RPRX 2020-12-28 09:40:28 +00:00 committed by GitHub
parent 6bcac6cb10
commit 13ad3fddf6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 59 additions and 50 deletions

View file

@ -2,9 +2,9 @@ package buf
import ( import (
"io" "io"
"net"
"github.com/xtls/xray-core/common/bytespool" "github.com/xtls/xray-core/common/bytespool"
"github.com/xtls/xray-core/common/net"
) )
const ( const (
@ -21,7 +21,7 @@ type Buffer struct {
v []byte v []byte
start int32 start int32
end int32 end int32
UDP *net.UDPAddr UDP *net.Destination
} }
// New creates a Buffer with 0 length and 2K capacity. // New creates a Buffer with 0 length and 2K capacity.
@ -49,6 +49,7 @@ func (b *Buffer) Release() {
b.v = nil b.v = nil
b.Clear() b.Clear()
pool.Put(p) pool.Put(p)
b.UDP = nil
} }
// Clear clears the content of the buffer, results an empty buffer with // Clear clears the content of the buffer, results an empty buffer with

View file

@ -149,7 +149,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
if destination.Network == net.Network_TCP { if destination.Network == net.Network_TCP {
writer = buf.NewWriter(conn) writer = buf.NewWriter(conn)
} else { } else {
writer = NewPacketWriter(conn) writer = NewPacketWriter(conn, h, ctx)
} }
if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
@ -215,14 +215,18 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
return nil, err return nil, err
} }
b.Resize(0, int32(n)) b.Resize(0, int32(n))
b.UDP = d.(*net.UDPAddr) b.UDP = &net.Destination{
Address: net.IPAddress(d.(*net.UDPAddr).IP),
Port: net.Port(d.(*net.UDPAddr).Port),
Network: net.Network_UDP,
}
if r.Counter != nil { if r.Counter != nil {
r.Counter.Add(int64(n)) r.Counter.Add(int64(n))
} }
return buf.MultiBuffer{b}, nil return buf.MultiBuffer{b}, nil
} }
func NewPacketWriter(conn net.Conn) buf.Writer { func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context) buf.Writer {
iConn := conn iConn := conn
statConn, ok := iConn.(*internet.StatCouterConnection) statConn, ok := iConn.(*internet.StatCouterConnection)
if ok { if ok {
@ -236,6 +240,8 @@ func NewPacketWriter(conn net.Conn) buf.Writer {
return &PacketWriter{ return &PacketWriter{
PacketConnWrapper: c, PacketConnWrapper: c,
Counter: counter, Counter: counter,
Handler: h,
Context: ctx,
} }
} }
return &buf.SequentialWriter{Writer: conn} return &buf.SequentialWriter{Writer: conn}
@ -244,6 +250,8 @@ func NewPacketWriter(conn net.Conn) buf.Writer {
type PacketWriter struct { type PacketWriter struct {
*internet.PacketConnWrapper *internet.PacketConnWrapper
stats.Counter stats.Counter
*Handler
context.Context
} }
func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
@ -256,7 +264,18 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
var n int var n int
var err error var err error
if b.UDP != nil { if b.UDP != nil {
n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), b.UDP) if w.Handler.config.useIP() && b.UDP.Address.Family().IsDomain() {
ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
if ip != nil {
b.UDP.Address = ip
}
}
destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
if destAddr == nil {
b.Release()
continue
}
n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), destAddr)
} else { } else {
n, err = w.PacketConnWrapper.Write(b.Bytes()) n, err = w.PacketConnWrapper.Write(b.Bytes())
} }

View file

@ -235,10 +235,8 @@ func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer.Release() buffer.Release()
return nil, err return nil, err
} }
payload.UDP = &net.UDPAddr{ dest := u.Destination()
IP: u.Address.IP(), payload.UDP = &dest
Port: int(u.Port),
}
return buf.MultiBuffer{payload}, nil return buf.MultiBuffer{payload}, nil
} }
@ -254,18 +252,15 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if b == nil { if b == nil {
break break
} }
var packet *buf.Buffer request := w.Request
var err error
if b.UDP != nil { if b.UDP != nil {
request := &protocol.RequestHeader{ request = &protocol.RequestHeader{
User: w.Request.User, User: w.Request.User,
Address: net.IPAddress(b.UDP.IP), Address: b.UDP.Address,
Port: net.Port(b.UDP.Port), Port: b.UDP.Port,
} }
packet, err = EncodeUDPPacket(request, b.Bytes())
} else {
packet, err = EncodeUDPPacket(w.Request, b.Bytes())
} }
packet, err := EncodeUDPPacket(request, b.Bytes())
b.Release() b.Release()
if err != nil { if err != nil {
buf.ReleaseMulti(mb) buf.ReleaseMulti(mb)

View file

@ -81,8 +81,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
if payload.UDP != nil { if payload.UDP != nil {
request = &protocol.RequestHeader{ request = &protocol.RequestHeader{
User: request.User, User: request.User,
Address: net.IPAddress(payload.UDP.IP), Address: payload.UDP.Address,
Port: net.Port(payload.UDP.Port), Port: payload.UDP.Port,
} }
} }
@ -128,25 +128,24 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
continue continue
} }
destination := request.Destination()
currentPacketCtx := ctx currentPacketCtx := ctx
if inbound.Source.IsValid() { if inbound.Source.IsValid() {
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: inbound.Source, From: inbound.Source,
To: request.Destination(), To: destination,
Status: log.AccessAccepted, Status: log.AccessAccepted,
Reason: "", Reason: "",
Email: request.User.Email, Email: request.User.Email,
}) })
} }
newError("tunnelling request to ", request.Destination()).WriteToLog(session.ExportIDToError(currentPacketCtx)) newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(currentPacketCtx))
data.UDP = &net.UDPAddr{ data.UDP = &destination
IP: request.Address.IP(),
Port: int(request.Port),
}
if dest.Network == 0 { if dest.Network == 0 {
dest = request.Destination() // JUST FOLLOW THE FIREST PACKET dest = request.Destination() // JUST FOLLOW THE FIRST PACKET
} }
currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)

View file

@ -202,8 +202,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
if payload.UDP != nil { if payload.UDP != nil {
request = &protocol.RequestHeader{ request = &protocol.RequestHeader{
User: request.User, User: request.User,
Address: net.IPAddress(payload.UDP.IP), Address: payload.UDP.Address,
Port: net.Port(payload.UDP.Port), Port: payload.UDP.Port,
} }
} }
@ -244,24 +244,24 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
payload.Release() payload.Release()
continue continue
} }
destination := request.Destination()
currentPacketCtx := ctx currentPacketCtx := ctx
newError("send packet to ", request.Destination(), " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx)) newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: inbound.Source, From: inbound.Source,
To: request.Destination(), To: destination,
Status: log.AccessAccepted, Status: log.AccessAccepted,
Reason: "", Reason: "",
}) })
} }
payload.UDP = &net.UDPAddr{ payload.UDP = &destination
IP: request.Address.IP(),
Port: int(request.Port),
}
if dest.Network == 0 { if dest.Network == 0 {
dest = request.Destination() // JUST FOLLOW THE FIREST PACKET dest = destination // JUST FOLLOW THE FIRST PACKET
} }
currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)

View file

@ -134,12 +134,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if b == nil { if b == nil {
break break
} }
target := w.Target target := &w.Target
if b.UDP != nil { if b.UDP != nil {
target.Address = net.IPAddress(b.UDP.IP) target = b.UDP
target.Port = net.Port(b.UDP.Port)
} }
if _, err := w.writePacket(b.Bytes(), target); err != nil { if _, err := w.writePacket(b.Bytes(), *target); err != nil {
buf.ReleaseMulti(mb) buf.ReleaseMulti(mb)
return err return err
} }
@ -155,12 +154,11 @@ func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net
if b == nil { if b == nil {
break break
} }
source := dest source := &dest
if b.UDP != nil { if b.UDP != nil {
source.Address = net.IPAddress(b.UDP.IP) source = b.UDP
source.Port = net.Port(b.UDP.Port)
} }
if _, err := w.writePacket(b.Bytes(), source); err != nil { if _, err := w.writePacket(b.Bytes(), *source); err != nil {
buf.ReleaseMulti(mb) buf.ReleaseMulti(mb)
return err return err
} }
@ -312,10 +310,7 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
} }
b := buf.New() b := buf.New()
b.UDP = &net.UDPAddr{ b.UDP = &dest
IP: addr.IP(),
Port: int(port.Value()),
}
mb = append(mb, b) mb = append(mb, b)
n, err := b.ReadFullFrom(r, int32(length)) n, err := b.ReadFullFrom(r, int32(length))
if err != nil { if err != nil {

View file

@ -281,7 +281,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx)) newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx))
if dest.Network == 0 { if dest.Network == 0 {
dest = p.Target // JUST FOLLOW THE FIREST PACKET dest = p.Target // JUST FOLLOW THE FIRST PACKET
} }
for _, b := range p.Buffer { for _, b := range p.Buffer {

View file

@ -66,7 +66,7 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *c
cancel() cancel()
v.RemoveRay(dest) v.RemoveRay(dest)
} }
timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4) timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
link, _ := v.dispatcher.Dispatch(ctx, dest) link, _ := v.dispatcher.Dispatch(ctx, dest)
entry := &connEntry{ entry := &connEntry{
link: link, link: link,