Browse Source
The WebSocket connection is kept open in order to use it to notify shutdowns.pull/1377/head
7 changed files with 202 additions and 57 deletions
@ -0,0 +1,114 @@
@@ -0,0 +1,114 @@
|
||||
// Package websocket provides WebSocket connectivity.
|
||||
package websocket |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"net" |
||||
"net/http" |
||||
"time" |
||||
|
||||
"github.com/gorilla/websocket" |
||||
) |
||||
|
||||
var ( |
||||
pingInterval = 30 * time.Second |
||||
pingTimeout = 5 * time.Second |
||||
writeTimeout = 2 * time.Second |
||||
) |
||||
|
||||
var upgrader = websocket.Upgrader{ |
||||
CheckOrigin: func(r *http.Request) bool { |
||||
return true |
||||
}, |
||||
} |
||||
|
||||
// ServerConn is a server-side WebSocket connection with automatic, periodic ping / pong.
|
||||
type ServerConn struct { |
||||
wc *websocket.Conn |
||||
|
||||
// in
|
||||
terminate chan struct{} |
||||
write chan []byte |
||||
|
||||
// out
|
||||
writeErr chan error |
||||
} |
||||
|
||||
// NewServerConn allocates a ServerConn.
|
||||
func NewServerConn(w http.ResponseWriter, req *http.Request) (*ServerConn, error) { |
||||
wc, err := upgrader.Upgrade(w, req, nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c := &ServerConn{ |
||||
wc: wc, |
||||
terminate: make(chan struct{}), |
||||
write: make(chan []byte), |
||||
writeErr: make(chan error), |
||||
} |
||||
|
||||
go c.run() |
||||
|
||||
return c, nil |
||||
} |
||||
|
||||
// Close closes a ServerConn.
|
||||
func (c *ServerConn) Close() { |
||||
c.wc.Close() |
||||
close(c.terminate) |
||||
} |
||||
|
||||
// RemoteAddr returns the remote address.
|
||||
func (c *ServerConn) RemoteAddr() net.Addr { |
||||
return c.wc.RemoteAddr() |
||||
} |
||||
|
||||
func (c *ServerConn) run() { |
||||
c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout)) |
||||
|
||||
c.wc.SetPongHandler(func(string) error { |
||||
c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout)) |
||||
return nil |
||||
}) |
||||
|
||||
pingTicker := time.NewTicker(pingInterval) |
||||
defer pingTicker.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case byts := <-c.write: |
||||
c.wc.SetWriteDeadline(time.Now().Add(writeTimeout)) |
||||
err := c.wc.WriteMessage(websocket.TextMessage, byts) |
||||
c.writeErr <- err |
||||
|
||||
case <-pingTicker.C: |
||||
c.wc.SetWriteDeadline(time.Now().Add(writeTimeout)) |
||||
c.wc.WriteMessage(websocket.PingMessage, nil) |
||||
|
||||
case <-c.terminate: |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// ReadJSON reads a JSON object.
|
||||
func (c *ServerConn) ReadJSON(in interface{}) error { |
||||
return c.wc.ReadJSON(in) |
||||
} |
||||
|
||||
// WriteJSON writes a JSON object.
|
||||
func (c *ServerConn) WriteJSON(in interface{}) error { |
||||
byts, err := json.Marshal(in) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
select { |
||||
case c.write <- byts: |
||||
return <-c.writeErr |
||||
case <-c.terminate: |
||||
return fmt.Errorf("terminated") |
||||
} |
||||
} |
@ -0,0 +1,55 @@
@@ -0,0 +1,55 @@
|
||||
package websocket |
||||
|
||||
import ( |
||||
"context" |
||||
"net" |
||||
"net/http" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/gorilla/websocket" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestServerConn(t *testing.T) { |
||||
pingReceived := make(chan struct{}) |
||||
pingInterval = 100 * time.Millisecond |
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) { |
||||
c, err := NewServerConn(w, r) |
||||
require.NoError(t, err) |
||||
defer c.Close() |
||||
|
||||
err = c.WriteJSON("testing") |
||||
require.NoError(t, err) |
||||
|
||||
<-pingReceived |
||||
} |
||||
|
||||
ln, err := net.Listen("tcp", "localhost:6344") |
||||
require.NoError(t, err) |
||||
defer ln.Close() |
||||
|
||||
s := &http.Server{Handler: http.HandlerFunc(handler)} |
||||
go s.Serve(ln) |
||||
defer s.Shutdown(context.Background()) |
||||
|
||||
c, res, err := websocket.DefaultDialer.Dial("ws://localhost:6344/", nil) |
||||
require.NoError(t, err) |
||||
defer res.Body.Close() |
||||
defer c.Close() |
||||
|
||||
c.SetPingHandler(func(msg string) error { |
||||
close(pingReceived) |
||||
return nil |
||||
}) |
||||
|
||||
var msg string |
||||
err = c.ReadJSON(&msg) |
||||
require.NoError(t, err) |
||||
require.Equal(t, "testing", msg) |
||||
|
||||
c.ReadMessage() |
||||
|
||||
<-pingReceived |
||||
} |
Loading…
Reference in new issue