# session.py: Redis-backed Flask session interface with per-session request locks.
1
import json
2
import logging
3
4
5
6
from uuid import uuid4
from time import sleep

from werkzeug.datastructures import CallbackDict
7
from flask import session
8
9
10
from flask.sessions import SessionMixin, SessionInterface


11
12
13
# Module logger; named after this module so records can be filtered per file.
logger = logging.getLogger(name=__name__)


14
15
16
class RedisSession(CallbackDict, SessionMixin):
    """Session mapping identified by ``sid`` whose contents are kept in Redis.

    Every mutation of the mapping sets ``modified`` to ``True`` through the
    ``CallbackDict`` update callback, which lets the session interface decide
    whether the stored copy must be rewritten or deleted.
    """

    def __init__(self, sid, initial=None):
        """Create a session.

        :param sid: Session identifier; ``RedisSessionInterface`` uses it as
            the cookie value and in the Redis key names.
        :param initial: Previously stored session contents. When ``None``, a
            new session with empty ``data_tasks``, ``analytic_tasks`` and
            ``subsets`` lists is created.
        """
        if initial is None:
            # A fresh dict (and fresh lists) per call, so sessions never share
            # mutable state.
            initial = {'data_tasks': [], 'analytic_tasks': [], 'subsets': []}

        def on_update(self):
            # CallbackDict invokes this with the dict itself after a mutation.
            self.modified = True

        CallbackDict.__init__(self, initial, on_update)
        self.sid = sid
        self.permanent = True
        # Reset ``modified`` last: in Flask, assigning ``permanent`` stores a
        # '_permanent' item, which would otherwise fire on_update and mark a
        # freshly built session as changed (verify against installed Flask).
        self.modified = False

class RedisSessionInterface(SessionInterface):
    """Flask session interface that keeps session contents in Redis.

    The session cookie only carries the session id (``sid``); the data is
    stored as JSON under the Redis key ``session:<sid>``. Concurrent requests
    on the same session are serialised by a lock stored under
    ``session:<sid>:lock``: it is taken in :meth:`open_session` and released
    by a ``teardown_request`` handler registered in :meth:`__init__`.
    """

    #: Seconds after which a session lock expires even if never released.
    lock_timeout = 10
    #: Seconds to sleep between attempts to take a lock held by another request.
    lock_poll_interval = 0.1

    def __init__(self, redis, app):
        """Bind the interface to a Redis client and a Flask app.

        :param redis: Redis client used for both session data and locks.
        :param app: Flask application; a teardown handler that releases the
            current session's lock is registered on it.
        """
        self.redis = redis

        @app.teardown_request
        def teardown_request(exc=None) -> None:
            """Release session lock whatever happens.
            :param exc: Unhandled exception that might have been thrown.
            """
            # NOTE(review): this also runs when acquire_lock skipped locking
            # (no FLASK_REQUEST_ID), and release_lock deletes unconditionally,
            # so it can free a lock held by another request on this session.
            try:
                app.session_interface.release_lock(session.sid)
            except AttributeError:
                logger.warning(
                    "Attempted to release session lock but no session id was "
                    "found. This happens only during testing and should never "
                    "happen in production mode.")

    @staticmethod
    def _data_key(sid):
        """Return the Redis key holding the JSON data of session ``sid``."""
        return 'session:{}'.format(sid)

    @staticmethod
    def _lock_key(sid):
        """Return the Redis key of the lock guarding session ``sid``."""
        return 'session:{}:lock'.format(sid)

    @staticmethod
    def _cookie_name(app):
        """Return the configured session cookie name.

        Read from ``app.config`` because the ``app.session_cookie_name``
        attribute was deprecated in Flask 2.2 and removed in Flask 2.3.
        """
        return app.config['SESSION_COOKIE_NAME']

    def acquire_lock(self, sid, request_id):
        """Block until the lock of session ``sid`` is held by ``request_id``.

        The lock key is written atomically with ``NX`` (only if absent) and
        an expiry of :attr:`lock_timeout` seconds. Waiters therefore never
        overwrite the current holder's lock, and a lock that is never
        released still expires.

        :param sid: Session identifier.
        :param request_id: Identifier of the current request. When ``None``
            no lock is taken.
        """
        if request_id is None:
            return
        key = self._lock_key(sid)
        holder = self.redis.get(name=key)
        if isinstance(holder, bytes) and isinstance(request_id, str):
            # redis-py returns bytes unless the client decodes responses.
            holder = holder.decode('utf-8', errors='replace')
        if holder == request_id:
            # This request already holds the lock.
            return
        # SET NX returns a falsy value while another request holds the key.
        while not self.redis.set(name=key, value=request_id,
                                 nx=True, ex=self.lock_timeout):
            sleep(self.lock_poll_interval)

    def release_lock(self, sid):
        """Release the lock of session ``sid``.

        The lock key is deleted unconditionally, whichever request holds it.
        """
        self.redis.delete(self._lock_key(sid))

    def open_session(self, app, request):
        """Load (or create) the session for ``request``, taking its lock first.

        :param app: Flask application (provides the cookie name).
        :param request: Incoming request; its ``FLASK_REQUEST_ID`` environ
            entry identifies the lock owner.
        :return: A :class:`RedisSession`. It is new when the request has no
            session cookie or Redis holds no data for the cookie's ``sid``.
        """
        request_id = request.environ.get("FLASK_REQUEST_ID")
        sid = request.cookies.get(self._cookie_name(app))
        if not sid:
            sid = str(uuid4())
            self.acquire_lock(sid, request_id)
            return RedisSession(sid=sid)
        self.acquire_lock(sid, request_id)
        stored = self.redis.get(self._data_key(sid))
        if stored is None:
            return RedisSession(sid=sid)
        return RedisSession(sid=sid, initial=json.loads(stored))

    def save_session(self, app, session, response):
        """Persist ``session`` to Redis and set (or delete) the session cookie.

        An empty session that was modified is removed from Redis and its
        cookie is deleted. An empty, unmodified session is left untouched.
        Otherwise the data is stored for ``PERMANENT_SESSION_LIFETIME`` and
        the cookie is (re)issued.

        :param app: Flask application (cookie and lifetime configuration).
        :param session: The session to save (shadows the module-level
            ``flask.session`` proxy inside this method).
        :param response: Response on which the cookie is set or deleted.
        """
        path = self.get_cookie_path(app)
        domain = self.get_cookie_domain(app)
        cookie_name = self._cookie_name(app)
        if not session:
            if session.modified:
                self.redis.delete(self._data_key(session.sid))
                response.delete_cookie(cookie_name, domain=domain, path=path)
            return
        cookie_expiration_time = self.get_expiration_time(app, session)
        self.redis.setex(name=self._data_key(session.sid),
                         time=app.config['PERMANENT_SESSION_LIFETIME'],
                         value=json.dumps(dict(session)))
        # Set the cookie with the same path used by delete_cookie above, so a
        # later deletion actually removes it.
        response.set_cookie(key=cookie_name, value=session.sid,
                            expires=cookie_expiration_time, httponly=True,
                            domain=domain, path=path,
                            secure=self.get_cookie_secure(app))