HTTPUpgrade send headers with specified capitalization (#3430)

* Fix HTTPUpgrade header capitalization

* Chore

* Remove excess host headers

Chore : change httpupgrade header "upgrade" to "Upgrade" #3435
This commit is contained in:
风扇滑翔翼 2024-06-06 22:31:56 +08:00
parent f8ec93dfdd
commit 3654c0d710
3 changed files with 26 additions and 63 deletions

View File

@ -208,8 +208,10 @@ func (c *HttpUpgradeConfig) Build() (proto.Message, error) {
// Host priority: Host field > headers field > address. // Host priority: Host field > headers field > address.
if c.Host == "" && c.Headers["host"] != "" { if c.Host == "" && c.Headers["host"] != "" {
c.Host = c.Headers["host"] c.Host = c.Headers["host"]
delete(c.Headers,"host")
} else if c.Host == "" && c.Headers["Host"] != "" { } else if c.Host == "" && c.Headers["Host"] != "" {
c.Host = c.Headers["Host"] c.Host = c.Headers["Host"]
delete(c.Headers,"Host")
} }
config := &httpupgrade.Config{ config := &httpupgrade.Config{
Path: path, Path: path,

View File

@ -69,69 +69,23 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
requestURL.Scheme = "http" requestURL.Scheme = "http"
} }
var req *http.Request = nil requestURL.Host = dest.NetAddr()
requestURL.Path = transportConfiguration.GetNormalizedPath()
req := &http.Request{
Method: http.MethodGet,
URL: &requestURL,
Host: transportConfiguration.Host,
Header: make(http.Header),
}
for key, value := range transportConfiguration.Header {
AddHeader(req.Header, key, value)
}
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
if len(transportConfiguration.Header) == 0 { err = req.Write(conn)
requestURL.Host = dest.NetAddr() if err != nil {
requestURL.Path = transportConfiguration.GetNormalizedPath() return nil, err
req = &http.Request{
Method: http.MethodGet,
URL: &requestURL,
Host: transportConfiguration.Host,
Header: make(http.Header),
}
req.Header.Set("Connection", "upgrade")
req.Header.Set("Upgrade", "websocket")
err = req.Write(conn)
if err != nil {
return nil, err
}
} else {
var headersBuilder strings.Builder
headersBuilder.WriteString("GET ")
headersBuilder.WriteString(transportConfiguration.GetNormalizedPath())
headersBuilder.WriteString(" HTTP/1.1\r\n")
hasConnectionHeader := false
hasUpgradeHeader := false
hasHostHeader := false
for key, value := range transportConfiguration.Header {
if strings.ToLower(key) == "connection" {
hasConnectionHeader = true
}
if strings.ToLower(key) == "upgrade" {
hasUpgradeHeader = true
}
if strings.ToLower(key) == "host" {
hasHostHeader = true
}
headersBuilder.WriteString(key)
headersBuilder.WriteString(": ")
headersBuilder.WriteString(value)
headersBuilder.WriteString("\r\n")
}
if !hasConnectionHeader {
headersBuilder.WriteString("Connection: upgrade\r\n")
}
if !hasUpgradeHeader {
headersBuilder.WriteString("Upgrade: websocket\r\n")
}
if !hasHostHeader {
headersBuilder.WriteString("Host: ")
headersBuilder.WriteString(transportConfiguration.Host)
headersBuilder.WriteString("\r\n")
}
headersBuilder.WriteString("\r\n")
_, err = conn.Write([]byte(headersBuilder.String()))
if err != nil {
return nil, err
}
} }
connRF := &ConnRF{ connRF := &ConnRF{
@ -150,6 +104,13 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
return connRF, nil return connRF, nil
} }
//http.Header.Add() will convert headers to MIME header format.
//Some people don't like this because they want to send "Web*S*ocket".
//So we add a simple function to replace that method.
func AddHeader(header http.Header, key, value string) {
header[key] = append(header[key], value)
}
func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))

View File

@ -62,7 +62,7 @@ func (s *server) Handle(conn net.Conn) (stat.Connection, error) {
ProtoMinor: 1, ProtoMinor: 1,
Header: http.Header{}, Header: http.Header{},
} }
resp.Header.Set("Connection", "upgrade") resp.Header.Set("Connection", "Upgrade")
resp.Header.Set("Upgrade", "websocket") resp.Header.Set("Upgrade", "websocket")
err = resp.Write(conn) err = resp.Write(conn)
if err != nil { if err != nil {