Adjust Trojan & Socks handleUDPPayload

This commit is contained in:
RPRX 2021-01-08 06:00:51 +00:00 committed by GitHub
parent d5aeb6c545
commit fb0e517158
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 12 deletions

View file

@ -218,7 +218,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
conn.Write(udpMessage.Bytes()) conn.Write(udpMessage.Bytes())
}) })
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { inbound := session.InboundFromContext(ctx)
if inbound != nil && inbound.Source.IsValid() {
newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx)) newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
} }
@ -249,7 +250,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
currentPacketCtx := ctx currentPacketCtx := ctx
newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx)) newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { if inbound != nil && inbound.Source.IsValid() {
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: inbound.Source, From: inbound.Source,
To: destination, To: destination,

View file

@ -251,7 +251,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
udpPayload := packet.Payload udpPayload := packet.Payload
if udpPayload.UDP == nil {
udpPayload.UDP = &packet.Source udpPayload.UDP = &packet.Source
}
common.Must(clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload})) common.Must(clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}))
}) })
@ -274,23 +276,30 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
} }
mb2, b := buf.SplitFirst(mb) mb2, b := buf.SplitFirst(mb)
if b == nil {
continue
}
destination := *b.UDP destination := *b.UDP
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
currentPacketCtx := ctx
if inbound.Source.IsValid() {
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: inbound.Source, From: inbound.Source,
To: destination, To: destination,
Status: log.AccessAccepted, Status: log.AccessAccepted,
Reason: "", Reason: "",
Email: user.Email, Email: user.Email,
}) })
}
newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx)) newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx))
if !buf.Cone || dest == nil { if !buf.Cone || dest == nil {
dest = &destination dest = &destination
} }
udpServer.Dispatch(ctx, *dest, b) // first packet udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
for _, payload := range mb2 { for _, payload := range mb2 {
udpServer.Dispatch(ctx, *dest, payload) udpServer.Dispatch(currentPacketCtx, *dest, payload)
} }
} }
} }