Browse Source
The WebSocket connection is kept open in order to use it to notify shutdowns.pull/1377/head
7 changed files with 202 additions and 57 deletions
@ -0,0 +1,114 @@ |
|||||||
|
// Package websocket provides WebSocket connectivity.
|
||||||
|
package websocket |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"net" |
||||||
|
"net/http" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/gorilla/websocket" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
pingInterval = 30 * time.Second |
||||||
|
pingTimeout = 5 * time.Second |
||||||
|
writeTimeout = 2 * time.Second |
||||||
|
) |
||||||
|
|
||||||
|
var upgrader = websocket.Upgrader{ |
||||||
|
CheckOrigin: func(r *http.Request) bool { |
||||||
|
return true |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
// ServerConn is a server-side WebSocket connection with automatic, periodic ping / pong.
|
||||||
|
type ServerConn struct { |
||||||
|
wc *websocket.Conn |
||||||
|
|
||||||
|
// in
|
||||||
|
terminate chan struct{} |
||||||
|
write chan []byte |
||||||
|
|
||||||
|
// out
|
||||||
|
writeErr chan error |
||||||
|
} |
||||||
|
|
||||||
|
// NewServerConn allocates a ServerConn.
|
||||||
|
func NewServerConn(w http.ResponseWriter, req *http.Request) (*ServerConn, error) { |
||||||
|
wc, err := upgrader.Upgrade(w, req, nil) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
c := &ServerConn{ |
||||||
|
wc: wc, |
||||||
|
terminate: make(chan struct{}), |
||||||
|
write: make(chan []byte), |
||||||
|
writeErr: make(chan error), |
||||||
|
} |
||||||
|
|
||||||
|
go c.run() |
||||||
|
|
||||||
|
return c, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes a ServerConn.
|
||||||
|
func (c *ServerConn) Close() { |
||||||
|
c.wc.Close() |
||||||
|
close(c.terminate) |
||||||
|
} |
||||||
|
|
||||||
|
// RemoteAddr returns the remote address.
|
||||||
|
func (c *ServerConn) RemoteAddr() net.Addr { |
||||||
|
return c.wc.RemoteAddr() |
||||||
|
} |
||||||
|
|
||||||
|
func (c *ServerConn) run() { |
||||||
|
c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout)) |
||||||
|
|
||||||
|
c.wc.SetPongHandler(func(string) error { |
||||||
|
c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout)) |
||||||
|
return nil |
||||||
|
}) |
||||||
|
|
||||||
|
pingTicker := time.NewTicker(pingInterval) |
||||||
|
defer pingTicker.Stop() |
||||||
|
|
||||||
|
for { |
||||||
|
select { |
||||||
|
case byts := <-c.write: |
||||||
|
c.wc.SetWriteDeadline(time.Now().Add(writeTimeout)) |
||||||
|
err := c.wc.WriteMessage(websocket.TextMessage, byts) |
||||||
|
c.writeErr <- err |
||||||
|
|
||||||
|
case <-pingTicker.C: |
||||||
|
c.wc.SetWriteDeadline(time.Now().Add(writeTimeout)) |
||||||
|
c.wc.WriteMessage(websocket.PingMessage, nil) |
||||||
|
|
||||||
|
case <-c.terminate: |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// ReadJSON reads a JSON object.
|
||||||
|
func (c *ServerConn) ReadJSON(in interface{}) error { |
||||||
|
return c.wc.ReadJSON(in) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteJSON writes a JSON object.
|
||||||
|
func (c *ServerConn) WriteJSON(in interface{}) error { |
||||||
|
byts, err := json.Marshal(in) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
select { |
||||||
|
case c.write <- byts: |
||||||
|
return <-c.writeErr |
||||||
|
case <-c.terminate: |
||||||
|
return fmt.Errorf("terminated") |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,55 @@ |
|||||||
|
package websocket |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"net" |
||||||
|
"net/http" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/gorilla/websocket" |
||||||
|
"github.com/stretchr/testify/require" |
||||||
|
) |
||||||
|
|
||||||
|
func TestServerConn(t *testing.T) { |
||||||
|
pingReceived := make(chan struct{}) |
||||||
|
pingInterval = 100 * time.Millisecond |
||||||
|
|
||||||
|
handler := func(w http.ResponseWriter, r *http.Request) { |
||||||
|
c, err := NewServerConn(w, r) |
||||||
|
require.NoError(t, err) |
||||||
|
defer c.Close() |
||||||
|
|
||||||
|
err = c.WriteJSON("testing") |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
<-pingReceived |
||||||
|
} |
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "localhost:6344") |
||||||
|
require.NoError(t, err) |
||||||
|
defer ln.Close() |
||||||
|
|
||||||
|
s := &http.Server{Handler: http.HandlerFunc(handler)} |
||||||
|
go s.Serve(ln) |
||||||
|
defer s.Shutdown(context.Background()) |
||||||
|
|
||||||
|
c, res, err := websocket.DefaultDialer.Dial("ws://localhost:6344/", nil) |
||||||
|
require.NoError(t, err) |
||||||
|
defer res.Body.Close() |
||||||
|
defer c.Close() |
||||||
|
|
||||||
|
c.SetPingHandler(func(msg string) error { |
||||||
|
close(pingReceived) |
||||||
|
return nil |
||||||
|
}) |
||||||
|
|
||||||
|
var msg string |
||||||
|
err = c.ReadJSON(&msg) |
||||||
|
require.NoError(t, err) |
||||||
|
require.Equal(t, "testing", msg) |
||||||
|
|
||||||
|
c.ReadMessage() |
||||||
|
|
||||||
|
<-pingReceived |
||||||
|
} |
Loading…
Reference in new issue