diff --git a/Dockerfile b/Dockerfile index 89418c7..0aaa0cf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,6 +15,7 @@ ADD media /src/media ADD originals /src/originals Add playlists /src/playlists ADD transcodes /src/transcodes +ADD users /src/users Add ytdlp /src/ytdlp ADD go.mod /src/. diff --git a/handlers.go b/handlers.go index c8f6a48..90a026d 100644 --- a/handlers.go +++ b/handlers.go @@ -13,7 +13,6 @@ import ( "time" "github.com/labstack/echo/v4" - "golang.org/x/crypto/bcrypt" "gorm.io/gorm" "ytdlp-site/config" @@ -23,6 +22,7 @@ import ( "ytdlp-site/originals" "ytdlp-site/playlists" "ytdlp-site/transcodes" + "ytdlp-site/users" "ytdlp-site/ytdlp" ) @@ -37,7 +37,7 @@ func registerPostHandler(c echo.Context) error { username := c.FormValue("username") password := c.FormValue("password") - err := CreateUser(db, username, password) + err := users.Create(db, username, password) if err != nil { return c.String(http.StatusInternalServerError, "Error creating user") @@ -47,66 +47,13 @@ func registerPostHandler(c echo.Context) error { } func homeHandler(c echo.Context) error { - - // redirect to /videos if logged in - session, err := store.Get(c.Request(), "session") - if err == nil { - _, ok := session.Values["user_id"] - if ok { - fmt.Println("homeHandler: session contains user_id. Redirect to /video") - return c.Redirect(http.StatusSeeOther, "/videos") - } - } - - return c.Render(http.StatusOK, "home.html", - map[string]interface{}{ - "Footer": handlers.MakeFooter(), - }) -} - -func loginHandler(c echo.Context) error { - return c.Render(http.StatusOK, "login.html", nil) -} - -func loginPostHandler(c echo.Context) error { - username := c.FormValue("username") - password := c.FormValue("password") - - var user User - if err := db.Where("username = ?", username).First(&user).Error; err != nil { - return c.String(http.StatusUnauthorized, "Invalid credentials") - } - - if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { - return c.String(http.StatusUnauthorized, "Invalid credentials") - } - - session, err := store.Get(c.Request(), "session") + _, err := handlers.GetUser(c) if err != nil { - return c.String(http.StatusInternalServerError, "Unable to retrieve session") + return c.Redirect(http.StatusSeeOther, "/login") + } else { + fmt.Println("homeHandler: session contains user_id. Redirect to /video") + return c.Redirect(http.StatusSeeOther, "/videos") } - session.Values["user_id"] = user.ID - err = session.Save(c.Request(), c.Response().Writer) - - if err != nil { - return c.String(http.StatusInternalServerError, "Unable to save session") - } - - session, _ = store.Get(c.Request(), "session") - _, ok := session.Values["user_id"] - if !ok { - return c.String(http.StatusInternalServerError, "user_id was not saved as expected") - } - - fmt.Println("loginPostHandler: redirect to /download") - return c.Redirect(http.StatusSeeOther, "/download") -} - -func logoutHandler(c echo.Context) error { - session, _ := store.Get(c.Request(), "session") - delete(session.Values, "user_id") - session.Save(c.Request(), c.Response().Writer) - return c.Redirect(http.StatusSeeOther, "/login") } func downloadHandler(c echo.Context) error { diff --git a/handlers/init.go b/handlers/init.go index 16608b0..30b37f2 100644 --- a/handlers/init.go +++ b/handlers/init.go @@ -1,12 +1,34 @@ package handlers -import "github.com/sirupsen/logrus" +import ( + "ytdlp-site/config" + + "github.com/gorilla/sessions" + "github.com/sirupsen/logrus" +) var log *logrus.Logger +var store *sessions.CookieStore func Init(logger *logrus.Logger) error { log = logger.WithFields(logrus.Fields{ "component": "handlers", }).Logger + + // create the cookie store + key, err := config.GetSessionAuthKey() + if err != nil { + return err + } + store = sessions.NewCookieStore(key) + store.Options = &sessions.Options{ + Path: "/", + MaxAge: 30 * 24 * 60 * 60, // seconds + HttpOnly: true, + Secure: config.GetSecure(), + } + return nil } + +func Fini() {} diff --git a/handlers/login.go b/handlers/login.go new file mode 100644 index 0000000..af5cfb8 --- /dev/null +++ b/handlers/login.go @@ -0,0 +1,58 @@ +package handlers + +import ( + "fmt" + "net/http" + "ytdlp-site/database" + "ytdlp-site/users" + + "github.com/labstack/echo/v4" + "golang.org/x/crypto/bcrypt" +) + +func LoginPost(c echo.Context) error { + username := c.FormValue("username") + password := c.FormValue("password") + + db := database.Get() + + var user users.User + if err := db.Where("username = ?", username).First(&user).Error; err != nil { + return c.String(http.StatusUnauthorized, "Invalid credentials") + } + + if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { + return c.String(http.StatusUnauthorized, "Invalid credentials") + } + + session, err := store.Get(c.Request(), "session") + if err != nil { + return c.String(http.StatusInternalServerError, "Unable to retrieve session") + } + session.Values["user_id"] = user.ID + err = session.Save(c.Request(), c.Response().Writer) + + if err != nil { + return c.String(http.StatusInternalServerError, "Unable to save session") + } + + session, _ = store.Get(c.Request(), "session") + _, ok := session.Values["user_id"] + if !ok { + return c.String(http.StatusInternalServerError, "user_id was not saved as expected") + } + + fmt.Println("loginPostHandler: redirect to /download") + return c.Redirect(http.StatusSeeOther, "/download") +} + +func LoginGet(c echo.Context) error { + return c.Render(http.StatusOK, "login.html", nil) +} + +func LogoutGet(c echo.Context) error { + session, _ := store.Get(c.Request(), "session") + delete(session.Values, "user_id") + session.Save(c.Request(), c.Response().Writer) + return c.Redirect(http.StatusSeeOther, "/login") +} diff --git a/middleware.go b/handlers/middleware.go similarity index 80% rename from middleware.go rename to handlers/middleware.go index 348bf3b..96cece8 100644 --- a/middleware.go +++ b/handlers/middleware.go @@ -1,16 +1,13 @@ -package main +package handlers import ( "fmt" "net/http" - "github.com/gorilla/sessions" "github.com/labstack/echo/v4" ) -var store *sessions.CookieStore - -func authMiddleware(next echo.HandlerFunc) echo.HandlerFunc { +func AuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { session, err := store.Get(c.Request(), "session") if err != nil { diff --git a/handlers/session.go b/handlers/session.go new file mode 100644 index 0000000..80a9265 --- /dev/null +++ b/handlers/session.go @@ -0,0 +1,26 @@ +package handlers + +import ( + "fmt" + + "github.com/labstack/echo/v4" +) + +type User struct { + Id uint +} + +func GetUser(c echo.Context) (User, error) { + session, err := store.Get(c.Request(), "session") + if err == nil { + val, ok := session.Values["user_id"] + if ok { + return User{Id: val.(uint)}, nil + } else { + return User{}, fmt.Errorf("user_id not in session") + } + } else { + return User{}, fmt.Errorf("couldn't retureve session from store") + } + +} diff --git a/handlers/videos.go b/handlers/videos.go new file mode 100644 index 0000000..681e052 --- /dev/null +++ b/handlers/videos.go @@ -0,0 +1,53 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "ytdlp-site/originals" + + "github.com/labstack/echo/v4" +) + +func VideosEvents(c echo.Context) error { + + user, err := GetUser(c) + if err != nil { + return err + } + + req := c.Request() + res := c.Response() + + // Set headers for SSE + res.Header().Set(echo.HeaderContentType, "text/event-stream") + res.Header().Set("Cache-Control", "no-cache") + res.Header().Set("Connection", "keep-alive") + + // Create a channel to signal client disconnect + done := req.Context().Done() + + q := originals.Subscribe(user.Id) + defer originals.Unsubscribe(user.Id, q) + + // Send SSE messages + for { + select { + case <-done: + return nil + default: + event := <-q.Ch + + jsonData, err := json.Marshal(event) + if err != nil { + return err + } + + msg := fmt.Sprintf("data: %s\n\n", jsonData) + _, err = res.Write([]byte(msg)) + if err != nil { + return err + } + res.Flush() + } + } +} diff --git a/main.go b/main.go index 6348af1..e4bf201 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,6 @@ import ( "path/filepath" "time" - "github.com/gorilla/sessions" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "gorm.io/driver/sqlite" @@ -24,6 +23,7 @@ import ( "ytdlp-site/originals" "ytdlp-site/playlists" "ytdlp-site/transcodes" + "ytdlp-site/users" "ytdlp-site/ytdlp" ) @@ -31,7 +31,7 @@ var db *gorm.DB func ensureAdminAccount(db *gorm.DB) error { - var user User + var user users.User if err := db.Where("username = ?", "admin").First(&user).Error; err != nil { // no such user @@ -40,7 +40,7 @@ func ensureAdminAccount(db *gorm.DB) error { return err } - err = CreateUser(db, "admin", password) + err = users.Create(db, "admin", password) if err != nil { return err } @@ -96,10 +96,15 @@ func main() { // Migrate the schema db.AutoMigrate(&originals.Original{}, &playlists.Playlist{}, &media.Video{}, - &media.Audio{}, &User{}, &TempURL{}, &transcodes.Transcode{}) + &media.Audio{}, &users.User{}, &TempURL{}, &transcodes.Transcode{}) database.Init(db, log) defer database.Fini() + err = handlers.Init(log) + if err != nil { + panic(fmt.Sprintf("%v", err)) + } + defer handlers.Fini() go PeriodicCleanup() @@ -109,13 +114,6 @@ func main() { panic(fmt.Sprintf("failed to create admin user: %v", err)) } - // create the cookie store - key, err := config.GetSessionAuthKey() - if err != nil { - panic(fmt.Sprintf("%v", err)) - } - store = sessions.NewCookieStore(key) - // Initialize Echo e := echo.New() @@ -131,46 +129,37 @@ func main() { // Routes e.GET("/", homeHandler) - e.GET("/login", loginHandler) - e.POST("/login", loginPostHandler) + e.GET("/login", handlers.LoginGet) + e.POST("/login", handlers.LoginPost) // e.GET("/register", registerHandler) // e.POST("/register", registerPostHandler) - e.GET("/logout", logoutHandler) - e.GET("/download", downloadHandler, authMiddleware) - e.POST("/download", downloadPostHandler, authMiddleware) - e.GET("/videos", videosHandler, authMiddleware) - e.GET("/video/:id", videoHandler, authMiddleware) - e.POST("/video/:id/restart", videoRestartHandler, authMiddleware) - e.POST("/video/:id/delete", deleteOriginalHandler, authMiddleware) + e.GET("/logout", handlers.LogoutGet) + e.GET("/download", downloadHandler, handlers.AuthMiddleware) + e.POST("/download", downloadPostHandler, handlers.AuthMiddleware) + e.GET("/videos", videosHandler, handlers.AuthMiddleware) + e.GET("/video/:id", videoHandler, handlers.AuthMiddleware) + e.POST("/video/:id/restart", videoRestartHandler, handlers.AuthMiddleware) + e.POST("/video/:id/delete", deleteOriginalHandler, handlers.AuthMiddleware) e.GET("/temp/:token", tempHandler) - e.POST("/video/:id/process", processHandler, authMiddleware) - e.POST("/video/:id/toggle_watched", handlers.ToggleWatched, authMiddleware) - e.POST("/delete_video/:id", deleteVideoHandler, authMiddleware) - e.POST("/delete_audio/:id", deleteAudioHandler, authMiddleware) - e.POST("/transcode_to_video/:id", transcodeToVideoHandler, authMiddleware) - e.POST("/transcode_to_audio/:id", transcodeToAudioHandler, authMiddleware) - e.GET("/status", handlers.StatusGet, authMiddleware) + e.POST("/video/:id/process", processHandler, handlers.AuthMiddleware) + e.POST("/video/:id/toggle_watched", handlers.ToggleWatched, handlers.AuthMiddleware) + e.POST("/delete_video/:id", deleteVideoHandler, handlers.AuthMiddleware) + e.POST("/delete_audio/:id", deleteAudioHandler, handlers.AuthMiddleware) + e.POST("/transcode_to_video/:id", transcodeToVideoHandler, handlers.AuthMiddleware) + e.POST("/transcode_to_audio/:id", transcodeToAudioHandler, handlers.AuthMiddleware) + e.GET("/status", handlers.StatusGet, handlers.AuthMiddleware) - e.GET("/p/:id", playlistHandler, authMiddleware) - e.POST("/p/:id/delete", deletePlaylistHandler, authMiddleware) + e.GET("/p/:id", playlistHandler, handlers.AuthMiddleware) + e.POST("/p/:id/delete", deletePlaylistHandler, handlers.AuthMiddleware) dataGroup := e.Group("/data") - dataGroup.Use(authMiddleware) + dataGroup.Use(handlers.AuthMiddleware) dataGroup.Static("/", config.GetDataDir()) staticGroup := e.Group("/static") - staticGroup.Use(authMiddleware) + staticGroup.Use(handlers.AuthMiddleware) staticGroup.Static("/", "static") - secure := config.GetSecure() - - store.Options = &sessions.Options{ - Path: "/", - MaxAge: 30 * 24 * 60 * 60, // seconds - HttpOnly: true, - Secure: secure, - } - // tidy up the transcodes database log.Debug("tidy transcodes database...") cleanupTranscodes() diff --git a/models.go b/models.go index 9c40fbc..9fb6b49 100644 --- a/models.go +++ b/models.go @@ -8,16 +8,8 @@ import ( "ytdlp-site/originals" "github.com/google/uuid" - "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" ) -type User struct { - gorm.Model - Username string `gorm:"unique"` - Password string -} - type TempURL struct { Token string `gorm:"uniqueIndex"` FilePath string @@ -37,15 +29,6 @@ type DownloadManager struct { mutex sync.RWMutex } -func CreateUser(db *gorm.DB, username, password string) error { - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - user := User{Username: username, Password: string(hashedPassword)} - if err := db.Create(&user).Error; err != nil { - return err - } - return nil -} - func SetOriginalStatus(id uint, status originals.Status) error { return db.Model(&originals.Original{}).Where("id = ?", id).Update("status", status).Error } diff --git a/originals/originals.go b/originals/originals.go index 96afc2a..e29b5c4 100644 --- a/originals/originals.go +++ b/originals/originals.go @@ -4,6 +4,7 @@ import ( "ytdlp-site/database" "ytdlp-site/transcodes" + "github.com/google/uuid" "gorm.io/gorm" ) @@ -58,5 +59,49 @@ func SetStatusTranscodingOrCompleted(id uint) error { log.Debugln("no transcodes for original", id) return SetStatus(id, StatusCompleted) } - +} + +type Event struct { + VideoId uint + Status Status +} + +type Queue struct { + id uuid.UUID + Ch chan Event +} + +func newQueue() *Queue { + return &Queue{ + id: uuid.Must(uuid.NewV7()), + Ch: make(chan Event), + } +} + +var listeners map[uint][]*Queue + +func Subscribe(userId uint) *Queue { + _, ok := listeners[userId] + if !ok { + listeners[userId] = make([]*Queue, 0) + } + q := newQueue() + listeners[userId] = append(listeners[userId], q) + return q +} + +func Unsubscribe(userId uint, q *Queue) { + + qs, ok := listeners[userId] + if !ok { + return + } + + newQs := []*Queue{} + for _, oldQ := range qs { + if oldQ != q { + newQs = append(newQs, oldQ) + } + } + listeners[userId] = newQs } diff --git a/users/user.go b/users/user.go new file mode 100644 index 0000000..b508802 --- /dev/null +++ b/users/user.go @@ -0,0 +1,21 @@ +package users + +import ( + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +type User struct { + gorm.Model + Username string `gorm:"unique"` + Password string +} + +func Create(db *gorm.DB, username, password string) error { + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + user := User{Username: username, Password: string(hashedPassword)} + if err := db.Create(&user).Error; err != nil { + return err + } + return nil +}