flask-login源码分析
flask-login 源码分析。
要点
印象比较深的,有以下几点。可以应用在自己的项目中。
- signal的使用(blinker)
- user_unauthorized 与 unauthorized_callback 的使用场景差别
- redirect_url 在
session和request.url的两种实现 - LocalProxy(current_user用到 werkzeug#ContextLocals)
signal 的使用
先看看用到 signal 的地方。
def unauthorized(self):
user_unauthorized.send(current_app._get_current_object())
if self.unauthorized_callback:
return self.unauthorized_callback()
if request.blueprint in self.blueprint_login_views:
login_view = self.blueprint_login_views[request.blueprint]
else:
login_view = self.login_view
if not login_view:
abort(401)
if self.login_message:
if self.localize_callback is not None:
flash(self.localize_callback(self.login_message),
category=self.login_message_category)
else:
flash(self.login_message, category=self.login_message_category)
config = current_app.config
if config.get('USE_SESSION_FOR_NEXT', USE_SESSION_FOR_NEXT):
login_url = expand_login_view(login_view)
session['next'] = make_next_param(login_url, request.url)
redirect_url = make_login_url(login_view)
else:
redirect_url = make_login_url(login_view, next_url=request.url)
return redirect(redirect_url)这是在 login_required 装饰器中调用的一段代码。
def login_required(func):
@wraps(func)
def decorated_view(*args, **kwargs):
if request.method in EXEMPT_METHODS:
return func(*args, **kwargs)
elif current_app.login_manager._login_disabled:
return func(*args, **kwargs)
elif not current_user.is_authenticated:
return current_app.login_manager.unauthorized()
return func(*args, **kwargs)
return decorated_view就是:
login_required检查用户是否登录状态。如果不是,就执行unauthorized方法unauthorized发出user_unauthorized信号。执行注册在这里的receiver(user_unauthorized.send(current_app._get_current_object()))- 之后判断开发者是否注册了
unauthorized_callback,如果注册了,就执行开发者注册的流程(替代默认流程)。否则执行默认流程。
再看看这个 user_unauthroized signal 的 send 方法,这是 blinker 提供的。
# blinker/base.py
def send(self, *sender, **kwargs):
# Using '*sender' rather than 'sender=None' allows 'sender' to be
# used as a keyword argument- i.e. it's an invisible name in the
# function signature.
if len(sender) == 0:
sender = None
elif len(sender) > 1:
raise TypeError('send() accepts only one positional argument, '
'%s given' % len(sender))
else:
sender = sender[0]
if not self.receivers:
return []
else:
return [(receiver, receiver(sender, **kwargs))
for receiver in self.receivers_for(sender)]一目了然啊。原来这个 signal 的 send 方法,就是把连接到这个信号上的所有接收器,都执行一次,然后把执行结果(接收方法,与方法运行结果)作为一个数组,返回给 send 的调用者。send 的第一个参数,就是信号接收器收到的 sender,即发送者。
真是方便。以后自己在需要设计 hook 的地方,也可以直接用 blinker 封装好的 signal、receiver、Namespace等组件。以前我是这么做的:
before_hook = None
def register_before_hook(func):
before_hook = func
def life():
before_hook()
do()
after_hook()要是用 blinker,就可以:
# signals.py 注册
_signals = Namespace()
before_life = _signals.signal('before_life')
# live.py 使用
def life():
before_life.send('<life object>')
...这样实现就优美多了。
user_unauthorized 与 unauthorized_callback 的使用场景差别
看这块代码
def unauthorized(self):
user_unauthorized.send(current_app._get_current_object())
if self.unauthorized_callback:
return self.unauthorized_callback()很清楚了。先执行接收器方法,再执行回调。回调最多只能有一个,而接收器可以定义任意个。
其实这样看来,接收器是完全可以替代掉 callback 方法的。
那作者为什么要先使用信号,再使用回调?我猜是回调可能更容易让使用者理解。
一般情况开发者定义回调就够用了。但是如果框架提供者,还想留给开发者更大的灵活性,比如与其它框架耦合的时候,其它框架可能也需要处理这个认证失败的信号,这样用 SIGNAL 来处理,就更灵活。
如果只用signal,开发者通过 signal 机制实现 callback,那开发者注册到这个信号之后,开发者可能并没注意到还有别人也注册到了这个信号,从而出现预期之外的事情,这可能不是框架设计者想看到的。。
所以作者又增加了callback,以在概念上更清晰的给开发者提示。
redirect_url 在 session 和 request.url 的两种实现
这块相对比较简单。直接看源码。
def xxx():
# ...
if config.get('USE_SESSION_FOR_NEXT', USE_SESSION_FOR_NEXT):
login_url = expand_login_view(login_view)
session['next'] = make_next_param(login_url, request.url)
redirect_url = make_login_url(login_view)
else:
redirect_url = make_login_url(login_view, next_url=request.url)
return redirect(redirect_url)逻辑是:
- 判断
next_url是从session中计算,还是根据当前url计算。 至于为什么要支持两种方式,比如用户直接访问的LOGIN页,这个时候当前URL并不是用户希望跳回的页面。当然还有其它情况。需要更精细的处理返回逻辑时,就不能用当前URL来简单的计算。 - 计算 新的
login?next=xxxx.
再看看 make_login_url 的算法。
def login_url(login_view, next_url=None, next_field='next'):
'''
Creates a URL for redirecting to a login page. If only `login_view` is
provided, this will just return the URL for it. If `next_url` is provided,
however, this will append a ``next=URL`` parameter to the query string
so that the login view can redirect back to that URL. Flask-Login's default
unauthorized handler uses this function when redirecting to your login url.
To force the host name used, set `FORCE_HOST_FOR_REDIRECTS` to a host. This
prevents from redirecting to external sites if request headers Host or
X-Forwarded-For are present.
:param login_view: The name of the login view. (Alternately, the actual
URL to the login view.)
:type login_view: str
:param next_url: The URL to give the login view for redirection.
:type next_url: str
:param next_field: What field to store the next URL in. (It defaults to
``next``.)
:type next_field: str
'''
base = expand_login_view(login_view)
if next_url is None:
return base
parsed_result = urlparse(base)
md = url_decode(parsed_result.query)
md[next_field] = make_next_param(base, next_url)
netloc = current_app.config.get('FORCE_HOST_FOR_REDIRECTS') or \
parsed_result.netloc
parsed_result = parsed_result._replace(netloc=netloc,
query=url_encode(md, sort=True))
return urlunparse(parsed_result)很直观,urlparse -> 替换 query 部分 -> urlunparse。
这当中调了一个私有方法 _replace,来对原 url 的query部分替换。
附上几段相关的代码,以更深入理解:
# urllib/parse.py
def urlparse(url, scheme='', allow_fragments=True):
"""Parse a URL into 6 components:
<scheme>://<netloc>/<path>;<params>?<query>#<fragment>
Return a 6-tuple: (scheme, netloc, path, params, query, fragment).
Note that we don't break the components up in smaller bits
(e.g. netloc is a single string) and we don't expand % escapes."""
url, scheme, _coerce_result = _coerce_args(url, scheme)
splitresult = urlsplit(url, scheme, allow_fragments)
scheme, netloc, url, query, fragment = splitresult
if scheme in uses_params and ';' in url:
url, params = _splitparams(url)
else:
params = ''
result = ParseResult(scheme, netloc, url, params, query, fragment)
return _coerce_result(result)
# urlparse 的结果是一个 ParseResult 对象。这个对象有 _replace 的保护方法。
class ParseResult(tuple):
'ParseResult(scheme, netloc, path, params, query, fragment)'
__slots__ = ()
_fields = ('scheme', 'netloc', 'path', 'params', 'query', 'fragment')
def __new__(_cls, scheme, netloc, path, params, query, fragment):
'Create new instance of ParseResult(scheme, netloc, path, params, query, fragment)'
return _tuple.__new__(_cls, (scheme, netloc, path, params, query, fragment))
@classmethod
def _make(cls, iterable, new=tuple.__new__, len=len):
'Make a new ParseResult object from a sequence or iterable'
result = new(cls, iterable)
if len(result) != 6:
raise TypeError('Expected 6 arguments, got %d' % len(result))
return result
def _replace(_self, **kwds):
'Return a new ParseResult object replacing specified fields with new values'
result = _self._make(map(kwds.pop, ('scheme', 'netloc', 'path', 'params', 'query', 'fragment'), _self))
if kwds:
raise ValueError('Got unexpected field names: %r' % list(kwds))
return result
def __repr__(self):
'Return a nicely formatted representation string'
return self.__class__.__name__ + '(scheme=%r, netloc=%r, path=%r, params=%r, query=%r, fragment=%r)' % self
def _asdict(self):
'Return a new OrderedDict which maps field names to their values.'
return OrderedDict(zip(self._fields, self))
def __getnewargs__(self):
'Return self as a plain tuple. Used by copy and pickle.'
return tuple(self)
scheme = _property(_itemgetter(0), doc='Alias for field number 0')
netloc = _property(_itemgetter(1), doc='Alias for field number 1')
path = _property(_itemgetter(2), doc='Alias for field number 2')
params = _property(_itemgetter(3), doc='Alias for field number 3')
query = _property(_itemgetter(4), doc='Alias for field number 4')
fragment = _property(_itemgetter(5), doc='Alias for field number 5')从这段代码中可以看到 _replace 的定义。
def _replace(_self, **kwds):
'Return a new ParseResult object replacing specified fields with new values'将来自己构建 URL,也需要替换参数的时候,可以参考这段实现。
LocalProxy的使用
先看看用到LocalProxy的源码。
# flask_login/signals.py
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
current_user = LocalProxy(lambda: _get_user())
# flask_login/utils.py
def _get_user():
if has_request_context() and not hasattr(_request_ctx_stack.top, 'user'):
current_app.login_manager._load_user()
return getattr(_request_ctx_stack.top, 'user', None)
# flask_login/login_manager.py
def _load_user(self):
'''Loads user from session or remember_me cookie as applicable'''
user_accessed.send(current_app._get_current_object())
# first check SESSION_PROTECTION
config = current_app.config
if config.get('SESSION_PROTECTION', self.session_protection):
deleted = self._session_protection()
if deleted:
return self.reload_user()
# If a remember cookie is set, and the session is not, move the
# cookie user ID to the session.
#
# However, the session may have been set if the user has been
# logged out on this request, 'remember' would be set to clear,
# so we should check for that and not restore the session.
is_missing_user_id = 'user_id' not in session
if is_missing_user_id:
cookie_name = config.get('REMEMBER_COOKIE_NAME', COOKIE_NAME)
header_name = config.get('AUTH_HEADER_NAME', AUTH_HEADER_NAME)
has_cookie = (cookie_name in request.cookies and
session.get('remember') != 'clear')
if has_cookie:
return self._load_from_cookie(request.cookies[cookie_name])
elif self.request_callback:
return self._load_from_request(request)
elif header_name in request.headers:
return self._load_from_header(request.headers[header_name])
return self.reload_user()这个has_request_context是什么呢?
# flask/ctx.py
def has_request_context():
"""If you have code that wants to test if a request context is there or
not this function can be used. For instance, you may want to take advantage
of request information if the request object is available, but fail
silently if it is unavailable.
::
class User(db.Model):
def __init__(self, username, remote_addr=None):
self.username = username
if remote_addr is None and has_request_context():
remote_addr = request.remote_addr
self.remote_addr = remote_addr
Alternatively you can also just test any of the context bound objects
(such as :class:`request` or :class:`g` for truthness)::
class User(db.Model):
def __init__(self, username, remote_addr=None):
self.username = username
if remote_addr is None and request:
remote_addr = request.remote_addr
self.remote_addr = remote_addr
.. versionadded:: 0.7
"""
return _request_ctx_stack.top is not None再看看_request_ctx_stack:
# context locals
_request_ctx_stack = LocalStack()好了。先看到这里。我们先分析一下。再看其它源码。
这个LocalStack / LocalProxy很是神秘。为什么要这么用呢?其实也简单。
request.user是要保证每个线程独立的。你不能让张三登录之后,就把李四的线程中定义的user给变成张三。
这样就需要一个按线程隔离的存储区域。线程销毁的时候,要把自己的信息也顺带销毁。防止线程ID被复用时,登录信息也被复用了。python的ThreadLocal就是做这个事情的。
先不看LocalStack / LocalProxy的源码。如果我们自己实现,会怎么写呢?我会这样:
# 定义一个global变量
ctx = threading.local()
def login(username, password):
# ...
if authenticated:
ctx.user = user这样写也并不复杂啊。那为什么werkzeug库又要引入自己的一套东西呢。这方面的分析文章很多了。直接看就好。
结合来说,引入LocalStack的背景,就是第三个文章中的总结:
使用thread local对象虽然可以基于线程存储全局变量,但是在Web应用中可能会存在如下问题: 1. 有些应用使用的是greenlet协程,这种情况下无法保证协程之间数据的隔离,因为不同的协程可以在同一个线程当中。 2. 即使使用的是线程,WSGI应用也无法保证每个http请求使用的都是不同的线程,因为后一个http请求可能使用的是之前的http请求的线程,这样的话存储于thread local中的数据可能是之前残留的数据。 为了解决上述问题,Werkzeug开发了自己的local对象,这也是为什么我们需要Werkzeug的local对象