Files
meshcore-analyzer/cmd/server/websocket.go
T
Kpa-clawbot ec0ebeda2f fix(#1793): WebSocket CheckOrigin allowlist (block cross-origin scrapers) (#1795)
## Summary

Closes the wide-open `/ws` WebSocket upgrader (`CheckOrigin: return
true`) that lets any browser origin scrape live packet data. Replaces it
with an explicit allowlist consulted from `cfg.CORSAllowedOrigins`, plus
an implicit same-origin allowance and an empty-Origin (non-browser
client) allowance.

Fixes #1793.

## Rules (`Hub.checkOrigin`)

- Empty `Origin` header → **allow** (non-browser clients; per-IP
rate/deny gating tracked separately in #1794).
- `Origin` host == request `Host` (case-insensitive) → **allow**
(same-origin).
- `Origin` matches an entry in `cfg.CORSAllowedOrigins` by exact
case-insensitive match → **allow**.
- `"*"` in `cfg.CORSAllowedOrigins` is **deliberately ignored** for
`/ws`. A startup `[ws] WARNING:` is logged once when present.
- Anything else → **reject** (gorilla returns 403).

### Deliberate divergence from CORS XHR

CORS XHR (`corsMiddleware`) still honors `"*"` for read-only
cross-origin GETs. The `/ws` upgrade does NOT, per OWASP's WebSocket
Security Cheat Sheet:

> Use an allowlist, not a denylist. Avoid wildcards or substring
matching.

—
https://cheatsheetseries.owasp.org/cheatsheets/WebSocket_Security_Cheat_Sheet.html

`"*"` on the WS path would re-open the exact CSWSH/scraping vector this
PR closes, so it is rejected with a startup warning rather than silently
honored. This intentional asymmetry is documented in the updated
`_comment_corsAllowedOrigins` in `config.example.json`.

## TDD red → green

- `e5974c6a` **RED** — adds `cmd/server/websocket_checkorigin_test.go`
with five cases; `SetAllowedOrigins` is introduced as a stub that does not
yet enforce anything, so the test compiles but fails on the assertion (CI
fails on this commit by design).
- `a4791dc3` **GREEN** — implements `Hub.checkOrigin`, wires
`SetAllowedOrigins` from `main.go`, updates the config example. All
tests pass.

## Tests added (`cmd/server/websocket_checkorigin_test.go`)

- `TestCheckOriginRejectsForeignOrigin` — foreign Origin → 403
- `TestCheckOriginAllowsEmptyOrigin` — non-browser client → 101
- `TestCheckOriginAllowsSameHost` — same-origin → 101
- `TestCheckOriginAllowsAllowlistedOrigin` — exact allowlist match → 101
- `TestCheckOriginWildcardDoesNotAllowForeignOrigin` — `"*"` in
allowlist still rejects foreign origin → 403

## Files changed

- `cmd/server/websocket.go` — `Hub.allowedOrigins`, `SetAllowedOrigins`,
`checkOrigin`, wired into `Upgrader.CheckOrigin`.
- `cmd/server/main.go` — `hub.SetAllowedOrigins(cfg.CORSAllowedOrigins)`
at the single call site.
- `cmd/server/websocket_checkorigin_test.go` — new test file.
- `config.example.json` — updated `_comment_corsAllowedOrigins` to
document `/ws` gating and the `"*"` divergence.

## Out of scope (follow-up)

- **#1794** — per-IP rate limit / deny list / connection cap for
non-browser clients (which still bypass Origin because they don't send
one). Layered defense; not in this PR.

## Verification

- `go test ./cmd/server/...` — all server tests pass locally (574s).
- Preflight clean (`bash
~/.openclaw/skills/pr-preflight/scripts/run-all.sh origin/master`).

---------

Co-authored-by: openclaw-bot <bot@openclaw.local>
2026-06-29 05:48:58 -07:00

330 lines
8.3 KiB
Go

package main
import (
"encoding/json"
"log"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Hub manages WebSocket clients and broadcasts.
//
// A Hub owns the set of registered clients, the upgrader used to turn HTTP
// requests into WebSocket connections, and the origin allowlist that the
// upgrader's CheckOrigin hook consults.
type Hub struct {
	// mu guards clients and allowedOrigins.
	mu sync.RWMutex
	// clients is the set of currently registered connections.
	clients map[*Client]bool
	// upgrader performs the HTTP → WebSocket handshake in ServeWS.
	upgrader websocket.Upgrader
	allowedOrigins []string // exact-match allowlist for /ws CheckOrigin (see SetAllowedOrigins)
}
// SetAllowedOrigins configures the exact-match origin allowlist consulted by
// the WebSocket upgrader's CheckOrigin. The "*" wildcard is deliberately NOT
// honored here (it IS honored by the HTTP CORS middleware): OWASP's
// WebSocket Security Cheat Sheet recommends an explicit allowlist for CSWSH
// defense. If "*" appears in the slice, it is ignored and a startup WARN is
// logged once per call.
//
// The entries are copied into a freshly allocated slice on every call. The
// previous slice is never written to, so a reader that still holds a
// reference to it never sees entries change underneath it, and later
// mutations of the caller's origins slice do not leak into the hub.
//
// See: https://cheatsheetseries.owasp.org/cheatsheets/WebSocket_Security_Cheat_Sheet.html
func (h *Hub) SetAllowedOrigins(origins []string) {
	fresh := make([]string, len(origins))
	copy(fresh, origins)

	h.mu.Lock()
	h.allowedOrigins = fresh
	h.mu.Unlock()

	// Warn outside the critical section; logging does not need the lock.
	for _, o := range origins {
		if o == "*" {
			log.Println(`[ws] WARNING: CORSAllowedOrigins contains "*" — CORS allows any origin for XHR, but /ws upgrade enforces explicit allowlist only (OWASP CSWSH guidance). Add specific origins to allow cross-origin WebSocket clients.`)
			break
		}
	}
}
// checkOrigin is the gorilla/websocket Upgrader.CheckOrigin hook. Rules:
//   - empty Origin header → allow (non-browser client; rate-limit / IP gate
//     is handled separately, see #1794).
//   - Origin host == request Host (same-origin, case-insensitive) → allow.
//     An Origin that carries no host at all ("null", "file://") never counts
//     as same-origin, even when the request's Host is empty too.
//   - Origin in allowedOrigins by exact case-insensitive match → allow.
//   - "*" in allowedOrigins is ignored (see SetAllowedOrigins).
//   - anything else, including an unparsable Origin → reject (gorilla
//     returns 403).
//
// It reports whether the upgrade request r may proceed.
func (h *Hub) checkOrigin(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return true
	}
	u, err := url.Parse(origin)
	if err != nil {
		return false
	}
	// Guard against "" == "": a host-less Origin must not match a request
	// that arrived without a Host.
	if u.Host != "" && strings.EqualFold(u.Host, r.Host) {
		return true
	}

	// Hold the read lock for the whole scan so the allowlist cannot be
	// replaced or rewritten while it is being iterated.
	h.mu.RLock()
	defer h.mu.RUnlock()
	for _, o := range h.allowedOrigins {
		if o == "*" {
			continue // deliberately not honored — see SetAllowedOrigins
		}
		if strings.EqualFold(o, origin) {
			return true
		}
	}
	return false
}
// Client is a single WebSocket connection.
type Client struct {
	// conn is the underlying WebSocket connection.
	conn *websocket.Conn
	// send carries outbound messages; writePump drains it and treats a
	// closed channel as the signal to send a close frame and exit.
	send chan []byte
	// closeOnce ensures send is closed at most once, whichever of
	// Hub.Unregister or Hub.Close gets there first.
	closeOnce sync.Once
}
// NewHub returns a Hub with an empty client set whose upgrader uses a
// 1024-byte read buffer, a 4096-byte write buffer, and the hub's own
// checkOrigin method as its CheckOrigin hook.
func NewHub() *Hub {
	hub := new(Hub)
	hub.clients = make(map[*Client]bool)
	// The upgrader is filled in after hub exists because CheckOrigin is a
	// method value bound to this very hub.
	hub.upgrader = websocket.Upgrader{
		CheckOrigin:     hub.checkOrigin,
		ReadBufferSize:  1024,
		WriteBufferSize: 4096,
	}
	return hub
}
// ClientCount reports how many clients are currently registered. The value
// is read under the hub's read lock.
func (h *Hub) ClientCount() int {
	h.mu.RLock()
	total := len(h.clients)
	h.mu.RUnlock()
	return total
}
// Register adds c to the hub's client set and logs the new total. The total
// is read after the write lock has been released.
func (h *Hub) Register(c *Client) {
	func() {
		h.mu.Lock()
		defer h.mu.Unlock()
		h.clients[c] = true
	}()

	total := h.ClientCount()
	log.Printf("[ws] client connected (%d total)", total)
}
// Unregister removes c from the hub and closes its send channel. A client
// that is not in the set is left untouched, so calling this repeatedly is
// harmless. The remaining total is logged either way.
func (h *Hub) Unregister(c *Client) {
	h.mu.Lock()
	_, known := h.clients[c]
	if known {
		delete(h.clients, c)
		// closeOnce protects against a second close of the same channel.
		c.closeOnce.Do(func() { close(c.send) })
	}
	h.mu.Unlock()

	remaining := h.ClientCount()
	log.Printf("[ws] client disconnected (%d total)", remaining)
}
// Close gracefully disconnects all WebSocket clients.
//
// The current client set is detached from the hub under the write lock and
// replaced with an empty one; the clients are then notified with the lock
// released. Each client is sent a "going away" close frame (best effort,
// 3-second deadline) and has its send channel closed. Doing the network
// writes outside the lock keeps a slow or stalled peer from blocking
// Broadcast, Register, Unregister and ClientCount for the duration of the
// shutdown.
func (h *Hub) Close() {
	h.mu.Lock()
	closing := h.clients
	h.clients = make(map[*Client]bool)
	h.mu.Unlock()

	// Once detached, no Broadcast can reach these clients and Unregister
	// finds nothing to remove, so closing send below cannot race with a send.
	goingAway := websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down")
	for c := range closing {
		// Best effort: the peer may already be gone, so a failed close
		// frame is not worth reporting.
		_ = c.conn.WriteControl(websocket.CloseMessage, goingAway, time.Now().Add(3*time.Second))
		c.closeOnce.Do(func() { close(c.send) })
	}
	log.Println("[ws] all clients disconnected")
}
// Broadcast JSON-encodes msg once and offers the encoded bytes to every
// registered client. A client whose send buffer is full simply misses this
// message; the send never blocks. If msg cannot be encoded the error is
// logged and nothing is sent.
func (h *Hub) Broadcast(msg interface{}) {
	payload, encodeErr := json.Marshal(msg)
	if encodeErr != nil {
		log.Printf("[ws] marshal error: %v", encodeErr)
		return
	}

	h.mu.RLock()
	defer h.mu.RUnlock()
	for client := range h.clients {
		select {
		case client.send <- payload:
			// queued
		default:
			// buffer full — drop for this client
		}
	}
}
// ServeWS upgrades the request to a WebSocket connection, registers the new
// client with the hub, and starts its write and read pumps on their own
// goroutines. When the upgrade fails the error is logged and nothing is
// registered.
func (h *Hub) ServeWS(w http.ResponseWriter, r *http.Request) {
	wsConn, upgradeErr := h.upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		log.Printf("[ws] upgrade error: %v", upgradeErr)
		return
	}

	c := new(Client)
	c.conn = wsConn
	c.send = make(chan []byte, 256) // per-client outbound buffer

	h.Register(c)
	go c.writePump()
	go c.readPump(h)
}
// wsOrStatic returns a handler that routes on the Upgrade header alone: a
// request whose Upgrade header equals "websocket" (case-insensitive) goes to
// hub.ServeWS regardless of path, and everything else goes to static.
func wsOrStatic(hub *Hub, static http.Handler) http.Handler {
	route := func(w http.ResponseWriter, r *http.Request) {
		wantsUpgrade := strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
		if !wantsUpgrade {
			static.ServeHTTP(w, r)
			return
		}
		hub.ServeWS(w, r)
	}
	return http.HandlerFunc(route)
}
// readPump reads and discards incoming messages until the connection
// reports an error, then unregisters the client from hub and closes the
// connection. Incoming messages are capped at 512 bytes. The read deadline
// starts at 60 seconds and is pushed out another 60 seconds by every pong.
func (c *Client) readPump(hub *Hub) {
	const idleTimeout = 60 * time.Second

	defer func() {
		hub.Unregister(c)
		c.conn.Close()
	}()

	extendDeadline := func() {
		c.conn.SetReadDeadline(time.Now().Add(idleTimeout))
	}

	c.conn.SetReadLimit(512)
	extendDeadline()
	c.conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	for {
		// Payloads are ignored; only the error matters.
		if _, _, readErr := c.conn.ReadMessage(); readErr != nil {
			return
		}
	}
}
// writePump forwards queued messages from c.send to the connection as text
// frames and sends a ping every 30 seconds. Every write gets a 10-second
// deadline. It returns, stopping the ticker and closing the connection,
// when a write fails or when c.send is closed; in the latter case a close
// frame is written first.
func (c *Client) writePump() {
	const (
		pingEvery = 30 * time.Second
		writeWait = 10 * time.Second
	)

	pinger := time.NewTicker(pingEvery)
	// Deferred calls run last-in first-out: the ticker is stopped before
	// the connection is closed.
	defer c.conn.Close()
	defer pinger.Stop()

	for {
		select {
		case msg, open := <-c.send:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !open {
				// The hub closed the channel: say goodbye and stop.
				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if c.conn.WriteMessage(websocket.TextMessage, msg) != nil {
				return
			}
		case <-pinger.C:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if c.conn.WriteMessage(websocket.PingMessage, nil) != nil {
				return
			}
		}
	}
}
// Poller watches for new transmissions in SQLite and broadcasts them.
type Poller struct {
	// db is the database polled for new rows.
	db *DB
	// hub receives every new packet via Broadcast.
	hub *Hub
	store *PacketStore // optional: if set, new transmissions are ingested into memory
	// interval is the period of the polling ticker created in Start.
	interval time.Duration
	// stop is closed by Stop to make Start return.
	stop chan struct{}
}
// NewPoller returns a Poller that polls db every interval and broadcasts
// through hub. The optional store is left unset, and a fresh stop channel
// is created for Stop to close.
func NewPoller(db *DB, hub *Hub, interval time.Duration) *Poller {
	p := new(Poller)
	p.db = db
	p.hub = hub
	p.interval = interval
	p.stop = make(chan struct{})
	return p
}
// Start runs the polling loop on the calling goroutine and returns only
// after p.stop has been closed (see Stop).
//
// Two cursors are tracked: the highest transmission id and the highest
// observation id already handled. They start from the database maxima,
// raised to the store's maxima when a store is attached. On every tick new
// rows past the cursors are fetched and each one is broadcast through the
// hub as a WSMessage with Type "packet".
func (p *Poller) Start() {
	lastID := p.db.GetMaxTransmissionID()
	lastObsID := p.db.GetMaxObservationID()
	// If the store already loaded data, use its max IDs as a floor.
	// This prevents replaying the entire DB when the DB query fails
	// (e.g., corrupted DB returns 0 from COALESCE).
	if p.store != nil {
		if storeMax := p.store.MaxTransmissionID(); storeMax > lastID {
			lastID = storeMax
		}
		if storeMaxObs := p.store.MaxObservationID(); storeMaxObs > lastObsID {
			lastObsID = storeMaxObs
		}
	}
	log.Printf("[poller] starting from transmission ID %d, obs ID %d, interval %v", lastID, lastObsID, p.interval)
	ticker := time.NewTicker(p.interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if p.store != nil {
				// Ingest new transmissions into in-memory store and broadcast
				newTxs, newMax := p.store.IngestNewFromDB(lastID, 100)
				if newMax > lastID {
					lastID = newMax
				}
				// Ingest new observations for existing transmissions (fixes #174)
				// First work out where the observation cursor should land: the
				// largest id among the next (at most 500) observations, or the
				// current cursor when there are none. A failed query leaves the
				// cursor where it is.
				nextObsID := lastObsID
				if err := p.db.conn.QueryRow(`
SELECT COALESCE(MAX(id), ?) FROM (
SELECT id FROM observations
WHERE id > ?
ORDER BY id ASC
LIMIT 500
)`, lastObsID, lastObsID).Scan(&nextObsID); err != nil {
					nextObsID = lastObsID
				}
				// Ingestion is asked for rows after the old cursor; the cursor
				// itself only advances afterwards.
				newObs := p.store.IngestNewObservations(lastObsID, 500)
				if nextObsID > lastObsID {
					lastObsID = nextObsID
				}
				if len(newTxs) > 0 {
					log.Printf("[broadcast] sending %d packets to %d clients (lastID now %d)", len(newTxs), p.hub.ClientCount(), lastID)
				}
				// Transmissions are broadcast before observations.
				for _, tx := range newTxs {
					p.hub.Broadcast(WSMessage{
						Type: "packet",
						Data: tx,
					})
				}
				for _, obs := range newObs {
					p.hub.Broadcast(WSMessage{
						Type: "packet",
						Data: obs,
					})
				}
			} else {
				// Fallback: direct DB query (used when store is nil, e.g. tests)
				newTxs, err := p.db.GetNewTransmissionsSince(lastID, 100)
				if err != nil {
					log.Printf("[poller] error: %v", err)
					// Skip this tick; the next one retries from the same cursor.
					continue
				}
				for _, tx := range newTxs {
					// NOTE(review): a failed assertion yields id == 0, which never
					// advances lastID. Confirm GetNewTransmissionsSince stores "id"
					// as int rather than int64.
					id, _ := tx["id"].(int)
					if id > lastID {
						lastID = id
					}
					// Copy packet fields for the nested packet (avoids circular ref)
					pkt := make(map[string]interface{}, len(tx))
					for k, v := range tx {
						pkt[k] = v
					}
					tx["packet"] = pkt
					p.hub.Broadcast(WSMessage{
						Type: "packet",
						Data: tx,
					})
				}
			}
		case <-p.stop:
			return
		}
	}
}
// Stop tells the polling loop started by Start to return by closing the
// poller's stop channel.
//
// Calling Stop again after it has already run is a no-op rather than a
// "close of closed channel" panic. The check-then-close is not atomic, so
// Stop calls racing from several goroutines still need external
// coordination.
func (p *Poller) Stop() {
	select {
	case <-p.stop:
		// Already closed by an earlier Stop.
	default:
		close(p.stop)
	}
}