Restrict access to LiveKit SFU by differentiating full-access and restricted Matrix users for room creation (#67)

* add new ENV variable LIVEKIT_FULL_ACCESS_HOMESERVERS to allow different handling between full-access and restricted users

* full-access / restricted user detection

* Create LiveKit room on the SFU in case of a full-acceess user prior to issuing the JWT token

* Support full-access for all users via wildcard `*`for all homeservers

* make the wildcard '*' the default of LIVEKIT_FULL_ACCESS_HOMESERVERS to mimic the previous behaviour

* more idomatic variable nameing

* More ideomatic order for of functions in main.go
This commit is contained in:
fkwp
2025-07-29 10:34:19 +02:00
committed by GitHub
parent 16a2ccf047
commit 114f0f4560
6 changed files with 359 additions and 268 deletions

280
main.go
View File

@@ -15,11 +15,14 @@ import (
"log"
"net/http"
"os"
"slices"
"strings"
"time"
"github.com/livekit/protocol/auth"
"github.com/livekit/protocol/livekit"
lksdk "github.com/livekit/server-sdk-go/v2"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/fclient"
@@ -27,8 +30,9 @@ import (
)
type Handler struct {
key, secret, lk_url string
skipVerifyTLS bool
key, secret, lkUrl string
fullAccessHomeservers []string
skipVerifyTLS bool
}
type OpenIDTokenType struct {
@@ -48,11 +52,79 @@ type SFUResponse struct {
JWT string `json:"jwt"`
}
func exchangeOIDCToken(
func readKeySecret() (string, string) {
// We initialize keys & secrets from environment variables
key := os.Getenv("LIVEKIT_KEY")
secret := os.Getenv("LIVEKIT_SECRET")
// We initialize potential key & secret path from environment variables
keyPath := os.Getenv("LIVEKIT_KEY_FROM_FILE")
secretPath := os.Getenv("LIVEKIT_SECRET_FROM_FILE")
keySecretPath := os.Getenv("LIVEKIT_KEY_FILE")
// If keySecretPath is set we read the file and split it into two parts
// It takes over any other initialization
if keySecretPath != "" {
if keySecretBytes, err := os.ReadFile(keySecretPath); err != nil {
log.Fatal(err)
} else {
keySecrets := strings.Split(string(keySecretBytes), ":")
if len(keySecrets) != 2 {
log.Fatalf("invalid key secret file format!")
}
key = keySecrets[0]
secret = keySecrets[1]
}
} else {
// If keySecretPath is not set, we try to read the key and secret from files
// If those files are not set, we return the key & secret from the environment variables
if keyPath != "" {
if keyBytes, err := os.ReadFile(keyPath); err != nil {
log.Fatal(err)
} else {
key = string(keyBytes)
}
}
if secretPath != "" {
if secretBytes, err := os.ReadFile(secretPath); err != nil {
log.Fatal(err)
} else {
secret = string(secretBytes)
}
}
}
// remove white spaces, new lines and carriage returns
// from key and secret
return strings.Trim(key, " \r\n"), strings.Trim(secret, " \r\n")
}
func getJoinToken(apiKey, apiSecret, room, identity string) (string, error) {
at := auth.NewAccessToken(apiKey, apiSecret)
canPublish := true
canSubscribe := true
grant := &auth.VideoGrant{
RoomJoin: true,
RoomCreate: false,
CanPublish: &canPublish,
CanSubscribe: &canSubscribe,
Room: room,
}
at.SetVideoGrant(grant).
SetIdentity(identity).
SetValidFor(time.Hour)
return at.ToJWT()
}
func exchangeOpenIdUserInfo(
ctx context.Context, token OpenIDTokenType, skipVerifyTLS bool,
) (*fclient.UserInfo, error) {
if token.AccessToken == "" || token.MatrixServerName == "" {
return nil, errors.New("missing parameters in OIDC token")
return nil, errors.New("missing parameters in openid token")
}
if skipVerifyTLS {
@@ -73,6 +145,26 @@ func exchangeOIDCToken(
return &userinfo, nil
}
func (h *Handler) isFullAccessUser(matrixServerName string) bool {
// Grant full access if wildcard '*' is present as the only entry
if len(h.fullAccessHomeservers) == 1 && h.fullAccessHomeservers[0] == "*" {
return true
}
// Check if the matrixServerName is in the list of full-access homeservers
return slices.Contains(h.fullAccessHomeservers, matrixServerName)
}
func (h *Handler) prepareMux() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("/sfu/get", h.handle)
mux.HandleFunc("/healthz", h.healthcheck)
return mux
}
func (h *Handler) healthcheck(w http.ResponseWriter, r *http.Request) {
log.Printf("Health check from %s", r.RemoteAddr)
@@ -98,8 +190,8 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
return
case "POST":
var sfu_access_request SFURequest
err := json.NewDecoder(r.Body).Decode(&sfu_access_request)
var sfuAccessRequest SFURequest
err := json.NewDecoder(r.Body).Decode(&sfuAccessRequest)
if err != nil {
log.Printf("Error decoding JSON: %v", err)
w.WriteHeader(http.StatusBadRequest)
@@ -113,7 +205,7 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
return
}
if sfu_access_request.Room == "" {
if sfuAccessRequest.Room == "" {
log.Printf("Request missing room")
w.WriteHeader(http.StatusBadRequest)
err = json.NewEncoder(w).Encode(gomatrix.RespError{
@@ -128,7 +220,7 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
// TODO: we should be sanitising the input here before using it
// e.g. only allowing `https://` URL scheme
userInfo, err := exchangeOIDCToken(r.Context(), sfu_access_request.OpenIDToken, h.skipVerifyTLS)
userInfo, err := exchangeOpenIdUserInfo(r.Context(), sfuAccessRequest.OpenIDToken, h.skipVerifyTLS)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
err = json.NewEncoder(w).Encode(gomatrix.RespError{
@@ -141,10 +233,18 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
return
}
log.Printf("Got user info for %s", userInfo.Sub)
// Does the user belong to homeservers granted full access
isFullAccessUser := h.isFullAccessUser(sfuAccessRequest.OpenIDToken.MatrixServerName)
log.Printf(
"Got Matrix user info for %s (%s)",
userInfo.Sub,
map[bool]string{true: "full access", false: "restricted access"}[isFullAccessUser],
)
// TODO: is DeviceID required? If so then we should have validated at the start of the request processing
token, err := getJoinToken(h.key, h.secret, sfu_access_request.Room, userInfo.Sub+":"+sfu_access_request.DeviceID)
lkIdentity := userInfo.Sub + ":" + sfuAccessRequest.DeviceID
token, err := getJoinToken(h.key, h.secret, sfuAccessRequest.Room, lkIdentity)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
err = json.NewEncoder(w).Encode(gomatrix.RespError{
@@ -157,7 +257,42 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
return
}
res := SFUResponse{URL: h.lk_url, JWT: token}
if isFullAccessUser {
roomClient := lksdk.NewRoomServiceClient(h.lkUrl, h.key, h.secret)
creationStart := time.Now().Unix()
room, err := roomClient.CreateRoom(
context.Background(), &livekit.CreateRoomRequest{
Name: sfuAccessRequest.Room,
EmptyTimeout: 5 * 60, // 5 Minutes to keep the room open if no one joins
DepartureTimeout: 20, // number of seconds to keep the room after everyone leaves
MaxParticipants: 0, // 0 == no limitation
},
)
if err != nil {
log.Printf("Unable to create room %s. Error message: %v", sfuAccessRequest.Room, err)
w.WriteHeader(http.StatusInternalServerError)
err = json.NewEncoder(w).Encode(gomatrix.RespError{
ErrCode: "M_UNKNOWN",
Err: "Unable to create room on SFU",
})
if err != nil {
log.Printf("failed to encode json error message! %v", err)
}
return
}
// Log the room creation time and the user info
isNewRoom := room.GetCreationTime() >= creationStart && room.GetCreationTime() <= time.Now().Unix()
log.Printf(
"%s LiveKit room sid: %s (alias: %s) for full-access Matrix user %s (LiveKit identity: %s)",
map[bool]string{true: "Created", false: "Using"}[isNewRoom],
room.Sid, sfuAccessRequest.Room, userInfo.Sub , lkIdentity,
)
}
res := SFUResponse{URL: h.lkUrl, JWT: token}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(res)
@@ -169,61 +304,6 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
}
}
func (h *Handler) prepareMux() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("/sfu/get", h.handle)
mux.HandleFunc("/healthz", h.healthcheck)
return mux
}
func readKeySecret() (string, string) {
// We initialize keys & secrets from environment variables
key := os.Getenv("LIVEKIT_KEY")
secret := os.Getenv("LIVEKIT_SECRET")
// We initialize potential key & secret path from environment variables
keyPath := os.Getenv("LIVEKIT_KEY_FROM_FILE")
secretPath := os.Getenv("LIVEKIT_SECRET_FROM_FILE")
keySecretPath := os.Getenv("LIVEKIT_KEY_FILE")
// If keySecretPath is set we read the file and split it into two parts
// It takes over any other initialization
if keySecretPath != "" {
if keySecretBytes, err := os.ReadFile(keySecretPath); err != nil {
log.Fatal(err)
} else {
key_secrets := strings.Split(string(keySecretBytes), ":")
if len(key_secrets) != 2 {
log.Fatalf("invalid key secret file format!")
}
key = key_secrets[0]
secret = key_secrets[1]
}
} else {
// If keySecretPath is not set, we try to read the key and secret from files
// If those files are not set, we return the key & secret from the environment variables
if keyPath != "" {
if keyBytes, err := os.ReadFile(keyPath); err != nil {
log.Fatal(err)
} else {
key = string(keyBytes)
}
}
if secretPath != "" {
if secretBytes, err := os.ReadFile(secretPath); err != nil {
log.Fatal(err)
} else {
secret = string(secretBytes)
}
}
}
return strings.Trim(key, " \r\n"), strings.Trim(secret, " \r\n")
}
func main() {
skipVerifyTLS := os.Getenv("LIVEKIT_INSECURE_SKIP_VERIFY_TLS") == "YES_I_KNOW_WHAT_I_AM_DOING"
if skipVerifyTLS {
@@ -233,47 +313,47 @@ func main() {
log.Printf("!!! WARNING !!! Use only for testing or debugging !!! WARNING !!!")
log.Println("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
}
lk_url := os.Getenv("LIVEKIT_URL")
lk_jwt_port := os.Getenv("LIVEKIT_JWT_PORT")
if lk_jwt_port == "" {
lk_jwt_port = "8080"
}
log.Printf("LIVEKIT_URL: %s, LIVEKIT_JWT_PORT: %s", lk_url, lk_jwt_port)
key, secret := readKeySecret()
lkUrl := os.Getenv("LIVEKIT_URL")
// Check if the key, secret or url are empty.
if key == "" || secret == "" || lk_url == "" {
if key == "" || secret == "" || lkUrl == "" {
log.Fatal("LIVEKIT_KEY[_FILE], LIVEKIT_SECRET[_FILE] and LIVEKIT_URL environment variables must be set")
}
fullAccessHomeservers := os.Getenv("LIVEKIT_FULL_ACCESS_HOMESERVERS")
if len(fullAccessHomeservers) == 0 {
// For backward compatibility we also check for LIVEKIT_LOCAL_HOMESERVERS
// TODO: Remove this backward compatibility in the near future.
localHomeservers := os.Getenv("LIVEKIT_LOCAL_HOMESERVERS")
if len(localHomeservers) > 0 {
log.Printf("!!! LIVEKIT_LOCAL_HOMESERVERS is deprecated, please use LIVEKIT_FULL_ACCESS_HOMESERVERS instead !!!")
fullAccessHomeservers = localHomeservers
} else {
// If no full access homeservers are set, we default to wildcard '*' to mimic the previous behavior.
// TODO: Remove defaulting to wildcard '*' (aka full-access for all users) in the near future.
log.Printf("LIVEKIT_FULL_ACCESS_HOMESERVERS not set, defaulting to wildcard (*) for full access")
fullAccessHomeservers = "*"
}
}
lkJwtPort := os.Getenv("LIVEKIT_JWT_PORT")
if lkJwtPort == "" {
lkJwtPort = "8080"
}
log.Printf("LIVEKIT_URL: %s, LIVEKIT_JWT_PORT: %s", lkUrl, lkJwtPort)
log.Printf("LIVEKIT_FULL_ACCESS_HOMESERVERS: %v", fullAccessHomeservers)
handler := &Handler{
key: key,
secret: secret,
lk_url: lk_url,
skipVerifyTLS: skipVerifyTLS,
key: key,
secret: secret,
lkUrl: lkUrl,
skipVerifyTLS: skipVerifyTLS,
fullAccessHomeservers: strings.Split(fullAccessHomeservers, ","),
}
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", lk_jwt_port), handler.prepareMux()))
}
func getJoinToken(apiKey, apiSecret, room, identity string) (string, error) {
at := auth.NewAccessToken(apiKey, apiSecret)
canPublish := true
canSubscribe := true
grant := &auth.VideoGrant{
RoomJoin: true,
RoomCreate: true,
CanPublish: &canPublish,
CanSubscribe: &canSubscribe,
Room: room,
}
at.SetVideoGrant(grant).
SetIdentity(identity).
SetValidFor(time.Hour)
return at.ToJWT()
}
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", lkJwtPort), handler.prepareMux()))
}