Add simple session management using redis
This commit is contained in:
26
atst/app.py
26
atst/app.py
@@ -2,6 +2,7 @@ import os
|
||||
from configparser import ConfigParser
|
||||
import tornado.web
|
||||
from tornado.web import url
|
||||
from redis import StrictRedis
|
||||
|
||||
from atst.handlers.main import MainHandler
|
||||
from atst.handlers.home import Home
|
||||
@@ -12,6 +13,7 @@ from atst.handlers.request_new import RequestNew
|
||||
from atst.handlers.dev import Dev
|
||||
from atst.home import home
|
||||
from atst.api_client import ApiClient
|
||||
from atst.sessions import RedisSessions
|
||||
|
||||
ENV = os.getenv("TORNADO_ENV", "dev")
|
||||
|
||||
@@ -20,7 +22,12 @@ def make_app(config, deps, **kwargs):
|
||||
|
||||
routes = [
|
||||
url(r"/", Home, {"page": "login"}, name="main"),
|
||||
url(r"/login", Login, {"authnid_client": deps["authnid_client"]}, name="login"),
|
||||
url(
|
||||
r"/login",
|
||||
Login,
|
||||
{"sessions": deps["sessions"], "authnid_client": deps["authnid_client"]},
|
||||
name="login",
|
||||
),
|
||||
url(r"/home", MainHandler, {"page": "home"}, name="home"),
|
||||
url(
|
||||
r"/workspaces/blank",
|
||||
@@ -64,7 +71,14 @@ def make_app(config, deps, **kwargs):
|
||||
]
|
||||
|
||||
if not ENV == "production":
|
||||
routes += [url(r"/login-dev", Dev, {"action": "login"}, name="dev-login")]
|
||||
routes += [
|
||||
url(
|
||||
r"/login-dev",
|
||||
Dev,
|
||||
{"action": "login", "sessions": deps["sessions"]},
|
||||
name="dev-login",
|
||||
)
|
||||
]
|
||||
|
||||
app = tornado.web.Application(
|
||||
routes,
|
||||
@@ -76,12 +90,17 @@ def make_app(config, deps, **kwargs):
|
||||
**kwargs,
|
||||
)
|
||||
app.config = config
|
||||
app.sessions = deps["sessions"]
|
||||
return app
|
||||
|
||||
|
||||
def make_deps(config):
|
||||
# we do not want to do SSL verify services in test and development
|
||||
validate_cert = ENV == "production"
|
||||
redis_client = StrictRedis.from_url(
|
||||
config["default"]["REDIS_URI"], decode_responses=True
|
||||
)
|
||||
|
||||
return {
|
||||
"authz_client": ApiClient(
|
||||
config["default"]["AUTHZ_BASE_URL"],
|
||||
@@ -98,6 +117,9 @@ def make_deps(config):
|
||||
api_version="v1",
|
||||
validate_cert=validate_cert,
|
||||
),
|
||||
"sessions": RedisSessions(
|
||||
redis_client, config["default"]["SESSION_TTL_SECONDS"]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from webassets import Environment, Bundle
|
||||
import tornado.web
|
||||
from atst.home import home
|
||||
from atst.sessions import SessionNotFoundError
|
||||
|
||||
assets = Environment(directory=home.child("scss"), url="/static")
|
||||
css = Bundle(
|
||||
@@ -21,17 +22,19 @@ class BaseHandler(tornado.web.RequestHandler):
|
||||
ns.update(helpers)
|
||||
return ns
|
||||
|
||||
def login(self, user):
|
||||
session_id = self.sessions.start_session(user)
|
||||
self.set_secure_cookie("atat", session_id)
|
||||
self.redirect("/home")
|
||||
|
||||
def get_current_user(self):
|
||||
if self.get_secure_cookie("atst"):
|
||||
return {
|
||||
"id": "9cb348f0-8102-4962-88c4-dac8180c904c",
|
||||
"email": "fake.user@mail.com",
|
||||
"first_name": "Fake",
|
||||
"last_name": "User",
|
||||
}
|
||||
cookie = self.get_secure_cookie("atat")
|
||||
if cookie:
|
||||
try:
|
||||
session = self.application.sessions.get_session(cookie)
|
||||
except SessionNotFoundError:
|
||||
self.redirect("/login")
|
||||
else:
|
||||
return None
|
||||
|
||||
# this is a temporary implementation until we have real sessions
|
||||
def _start_session(self):
|
||||
self.set_secure_cookie("atst", "valid-user-session")
|
||||
return session["user"]
|
||||
|
||||
@@ -2,13 +2,10 @@ from atst.handler import BaseHandler
|
||||
|
||||
|
||||
class Dev(BaseHandler):
|
||||
def initialize(self, action):
|
||||
def initialize(self, action, sessions):
|
||||
self.action = action
|
||||
self.sessions = sessions
|
||||
|
||||
def get(self):
|
||||
if self.action == "login":
|
||||
self._login()
|
||||
|
||||
def _login(self):
|
||||
self._start_session()
|
||||
self.redirect("/home")
|
||||
user = {"id": "164497f6-c1ea-4f42-a5ef-101da278c012"}
|
||||
self.login(user)
|
||||
|
||||
@@ -3,34 +3,35 @@ from atst.handler import BaseHandler
|
||||
|
||||
|
||||
class Login(BaseHandler):
|
||||
def initialize(self, authnid_client):
|
||||
def initialize(self, authnid_client, sessions):
|
||||
self.authnid_client = authnid_client
|
||||
self.sessions = sessions
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def get(self):
|
||||
token = self.get_query_argument("bearer-token")
|
||||
if token:
|
||||
valid = yield self._validate_login_token(token)
|
||||
if valid:
|
||||
self._start_session()
|
||||
self.redirect("/home")
|
||||
return
|
||||
user = yield self._fetch_user_info(token)
|
||||
if user:
|
||||
self.login(user)
|
||||
else:
|
||||
self.write_error(401)
|
||||
|
||||
url = self.get_login_url()
|
||||
self.redirect(url)
|
||||
return
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def _validate_login_token(self, token):
|
||||
def _fetch_user_info(self, token):
|
||||
try:
|
||||
response = yield self.authnid_client.post(
|
||||
"/validate", json={"token": token}
|
||||
)
|
||||
return response.code == 200
|
||||
if response.code == 200:
|
||||
return response.json["user"]
|
||||
|
||||
except tornado.httpclient.HTTPError as error:
|
||||
if error.response.code == 401:
|
||||
return False
|
||||
return None
|
||||
|
||||
else:
|
||||
raise error
|
||||
|
||||
71
atst/sessions.py
Normal file
71
atst/sessions.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from uuid import uuid4
|
||||
import json
|
||||
from redis import exceptions
|
||||
|
||||
|
||||
class SessionStorageError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SessionNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Sessions(object):
|
||||
def start_session(self, user):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_session(self, session_id):
|
||||
raise NotImplementedError()
|
||||
|
||||
def generate_session_id(self):
|
||||
return str(uuid4())
|
||||
|
||||
def build_session_dict(self, user=None):
|
||||
return {"user": user or {}}
|
||||
|
||||
|
||||
class DictSessions(Sessions):
|
||||
def __init__(self):
|
||||
self.sessions = {}
|
||||
|
||||
def start_session(self, user):
|
||||
session_id = self.generate_session_id()
|
||||
self.sessions[session_id] = self.build_session_dict(user=user)
|
||||
return session_id
|
||||
|
||||
def get_session(self, session_id):
|
||||
try:
|
||||
session = self.sessions[session_id]
|
||||
except KeyError:
|
||||
raise SessionNotFoundError
|
||||
|
||||
return session
|
||||
|
||||
|
||||
class RedisSessions(Sessions):
|
||||
def __init__(self, redis, ttl_seconds):
|
||||
self.redis = redis
|
||||
self.ttl_seconds = ttl_seconds
|
||||
|
||||
def start_session(self, user):
|
||||
session_id = self.generate_session_id()
|
||||
session_dict = self.build_session_dict(user=user)
|
||||
session_serialized = json.dumps(session_dict)
|
||||
try:
|
||||
self.redis.setex(session_id, self.ttl_seconds, session_serialized)
|
||||
except exceptions.ConnectionError:
|
||||
raise SessionStorageError
|
||||
return session_id
|
||||
|
||||
def get_session(self, session_id):
|
||||
try:
|
||||
session_serialized = self.redis.get(session_id)
|
||||
except exceptions.ConnectionError:
|
||||
raise
|
||||
|
||||
if session_serialized:
|
||||
self.redis.expire(session_id, self.ttl_seconds)
|
||||
return json.loads(session_serialized)
|
||||
else:
|
||||
raise SessionNotFoundError
|
||||
Reference in New Issue
Block a user