diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 908a95d3..3ebb5326 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -208,8 +208,10 @@ func (c *HttpUpgradeConfig) Build() (proto.Message, error) { // Host priority: Host field > headers field > address. if c.Host == "" && c.Headers["host"] != "" { c.Host = c.Headers["host"] + delete(c.Headers,"host") } else if c.Host == "" && c.Headers["Host"] != "" { c.Host = c.Headers["Host"] + delete(c.Headers,"Host") } config := &httpupgrade.Config{ Path: path, diff --git a/transport/internet/httpupgrade/dialer.go b/transport/internet/httpupgrade/dialer.go index 45c30506..d1a8bb0d 100644 --- a/transport/internet/httpupgrade/dialer.go +++ b/transport/internet/httpupgrade/dialer.go @@ -69,69 +69,23 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings * requestURL.Scheme = "http" } - var req *http.Request = nil + requestURL.Host = dest.NetAddr() + requestURL.Path = transportConfiguration.GetNormalizedPath() + req := &http.Request{ + Method: http.MethodGet, + URL: &requestURL, + Host: transportConfiguration.Host, + Header: make(http.Header), + } + for key, value := range transportConfiguration.Header { + AddHeader(req.Header, key, value) + } + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") - if len(transportConfiguration.Header) == 0 { - requestURL.Host = dest.NetAddr() - requestURL.Path = transportConfiguration.GetNormalizedPath() - req = &http.Request{ - Method: http.MethodGet, - URL: &requestURL, - Host: transportConfiguration.Host, - Header: make(http.Header), - } - - req.Header.Set("Connection", "upgrade") - req.Header.Set("Upgrade", "websocket") - - err = req.Write(conn) - if err != nil { - return nil, err - } - } else { - var headersBuilder strings.Builder - - headersBuilder.WriteString("GET ") - headersBuilder.WriteString(transportConfiguration.GetNormalizedPath()) - headersBuilder.WriteString(" HTTP/1.1\r\n") - hasConnectionHeader := false - hasUpgradeHeader := false - hasHostHeader := false - for key, value := range transportConfiguration.Header { - if strings.ToLower(key) == "connection" { - hasConnectionHeader = true - } - if strings.ToLower(key) == "upgrade" { - hasUpgradeHeader = true - } - if strings.ToLower(key) == "host" { - hasHostHeader = true - } - headersBuilder.WriteString(key) - headersBuilder.WriteString(": ") - headersBuilder.WriteString(value) - headersBuilder.WriteString("\r\n") - } - - if !hasConnectionHeader { - headersBuilder.WriteString("Connection: upgrade\r\n") - } - - if !hasUpgradeHeader { - headersBuilder.WriteString("Upgrade: websocket\r\n") - } - - if !hasHostHeader { - headersBuilder.WriteString("Host: ") - headersBuilder.WriteString(transportConfiguration.Host) - headersBuilder.WriteString("\r\n") - } - - headersBuilder.WriteString("\r\n") - _, err = conn.Write([]byte(headersBuilder.String())) - if err != nil { - return nil, err - } + err = req.Write(conn) + if err != nil { + return nil, err } connRF := &ConnRF{ @@ -150,6 +104,13 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings * return connRF, nil } +//http.Header.Add() will convert headers to MIME header format. +//Some people don't like this because they want to send "Web*S*ocket". +//So we add a simple function to replace that method. +func AddHeader(header http.Header, key, value string) { + header[key] = append(header[key], value) +} + func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) diff --git a/transport/internet/httpupgrade/hub.go b/transport/internet/httpupgrade/hub.go index 5c3ccbd3..7f873d7b 100644 --- a/transport/internet/httpupgrade/hub.go +++ b/transport/internet/httpupgrade/hub.go @@ -62,7 +62,7 @@ func (s *server) Handle(conn net.Conn) (stat.Connection, error) { ProtoMinor: 1, Header: http.Header{}, } - resp.Header.Set("Connection", "upgrade") + resp.Header.Set("Connection", "Upgrade") resp.Header.Set("Upgrade", "websocket") err = resp.Write(conn) if err != nil {