import importlib
import logging

from flask import flash
from requests.auth import HTTPBasicAuth

import APIFactory
import config
| |
| |
__author__ = 'hanl'


# Authentication state shared by this module. Both names are populated by
# load_provider(), which must run before either one is used:
#   provider        -- the active authentication provider instance
#   message_handler -- handler the providers use to report API errors
provider, message_handler = None, None
| |
| |
def load_provider(provider_name, handler):
    '''
    Import, instantiate and register the authentication provider class.

    :param provider_name: dotted path ``package.module.ClassName`` of the
        provider class. A name without a dot is used as both the module name
        and the class name (``Foo`` imports module ``Foo`` and takes ``Foo``).
    :param handler: message handler stored in the module-level
        ``message_handler``. It is set before instantiation because
        ``Provider.__init__`` reads it.
    :return: the new provider instance, also stored in the module-level
        ``provider``.
    :raises KeyError: if the module does not define the requested class.
    :raises ImportError: if the module cannot be imported.
    '''
    global message_handler, provider
    message_handler = handler

    # split on the LAST dot only: everything before it is the module path
    module_name, _, class_name = provider_name.rpartition('.')
    if not module_name:
        # keep the legacy behaviour for undotted names
        module_name = class_name

    # import_module returns the leaf module; __import__('a.b') would return 'a'
    module = importlib.import_module(module_name)
    provider_class = getattr(module, class_name, None)
    if provider_class is None:
        raise KeyError("the provider class %s is undefined!" % str(class_name))

    instance = provider_class()
    print("successfully loaded provider '%s'" % str(class_name))
    provider = instance
    return instance
| |
| |
class User(object):
    '''
    The user object handed to the login machinery.

    Only ``username`` is mandatory; every other field defaults to None and
    is filled in when a provider loads the full user record.
    '''

    def __init__(self, username, password=None, email=None, firstName=None, lastName=None, address=None,
                 institution=None, phone=None, country=None, id_token=None):
        '''
        :raises ValueError: if ``username`` is empty or None
        '''
        # validate first so an invalid name is never traced
        if not username:
            raise ValueError("the username must be set")
        # debug-level log rather than an unconditional stdout print; this
        # constructor runs on every provider get_user() call
        logging.getLogger(__name__).debug("the username %s", username)
        self.username = username
        self.password = password
        self.email = email
        self.firstName = firstName
        self.lastName = lastName
        self.address = address
        self.institution = institution
        self.phone = phone
        self.country = country
        self.id_token = id_token

    def is_authenticated(self):
        return True

    def has_details(self):
        '''
        :return: True when first name, last name and email are all set.
        '''
        return bool(self.firstName and self.lastName and self.email)

    def is_anonymous(self):
        return False

    def get_id(self):
        '''
        :return: id reference to retrieve user object from middleware
            (the username as a text string)
        '''
        try:
            text_type = unicode  # Python 2 text type
        except NameError:  # Python 3: ``str`` is the text type
            text_type = str
        return text_type(self.username)

    def is_active(self):
        return True

    def get_full_name(self):
        '''
        :return: "first last" when both names are set, otherwise None.
        '''
        if self.firstName and self.lastName:
            return u' '.join([self.firstName, self.lastName])
        return None
| |
| |
class Provider(object):
    '''
    Base class for authentication providers.

    Subclasses call these base implementations through ``super()`` before
    doing their own work, so the hooks are deliberately no-ops that return
    None rather than abstract methods.
    '''

    def __init__(self):
        # module-level handler that load_provider() sets before instantiating
        self.handler = message_handler

    def authenticate(self, username=None, password=None):
        pass

    def get_user(self, username=None, session=None, full=False):
        '''
        Validate the arguments shared by every provider's ``get_user``.

        :raises ValueError: if ``username`` is empty or ``session`` is None.
        :return: None; subclasses build the actual user object.
        '''
        # previously the ValueError was *returned* and silently discarded by
        # the subclasses, so this check never took effect
        if not username or session is None:
            raise ValueError("username and session must be provided!")

    def login(self, session=None, user=None):
        pass

    def logout(self, session=None):
        pass

    def is_authenticated(self):
        pass
| |
| |
class CustomProvider(Provider):
    '''
    Provider that requests an API token from ``auth/requestToken`` using HTTP
    basic auth and keeps it in the session under ``'api_token'``.
    '''

    def authenticate(self, username=None, password=None):
        pass

    def get_user(self, username=None, session=None, full=False):
        """
        Return a :class:`User` for ``username``.

        With ``full=True`` the profile fields come from the claims that
        ``APIFactory.decrypt_openid`` decodes from ``session['api_token']``.
        Otherwise only the username is set.
        """
        # shared argument validation
        super(CustomProvider, self).get_user(username, session, full)
        if not full:
            return User(username=username)
        claims = APIFactory.decrypt_openid(token=session['api_token'])
        return User(username=username, email=claims['email'], firstName=claims['firstName'],
                    lastName=claims['lastName'], address=claims['address'],
                    institution=claims['institution'], phone=claims['phone'])

    def login(self, session=None, user=None):
        '''
        Request an API token for ``user`` and store it in ``session['api_token']``.

        :param session: session mapping that receives the token
        :param user: user object carrying ``username`` and ``password``; the
            password is cleared once the request is made, even if it raises
        :return: boolean if login successful
        '''
        super(CustomProvider, self).login(session, user)

        try:
            response = APIFactory.get("auth/requestToken",
                                      auth=HTTPBasicAuth(username=user.username, password=user.password))
        finally:
            # never keep the plaintext password longer than needed
            user.password = None

        if response is None:
            return False
        if self.handler.isError(response):
            self.handler.notifyNext(response.json(), flash)
            return False
        # log the status only: the response body IS the API token
        logging.getLogger(__name__).debug("token request returned HTTP %i", response.status_code)
        session['api_token'] = response.content.replace('api_token ', '')
        return True

    def logout(self, session=None):
        '''
        Remove the API token from ``session``.

        :return: True if a token was removed, False if there was none (or no session)
        '''
        if not session or 'api_token' not in session:
            return False
        session.pop('api_token', None)
        return True

    def is_authenticated(self):
        '''
        Not implemented: always returns None.
        TODO: check that the stored API token has not expired.
        '''
        pass
| |
| |
class OAuth2Provider(Provider):
    '''
    Provider that uses the OAuth2 password grant against ``oauth2/token`` and
    stores ``access_token`` and ``id_token`` in the session.
    '''

    def authenticate(self, username=None, password=None):
        pass

    def get_user(self, username=None, session=None, full=False):
        """
        Return a :class:`User` for ``username``, or None if the profile lookup fails.

        With ``full=True`` the profile comes from the decoded ``id_token`` when
        both the ``openid`` and ``profile`` scopes are configured. Otherwise it
        comes from the ``user/info`` endpoint, called with the access token.
        Without ``full`` only the username is set.
        """
        # shared argument validation
        super(OAuth2Provider, self).get_user(username, session, full)

        if not full:
            # most tasks only need a user object, not the actual profile data
            return User(username=username)

        scopes = config.OPENID_CONNECT_SCOPES
        if "openid" in scopes and "profile" in scopes:
            claims = APIFactory.decrypt_openid(secret=config.OAUTH2_CLIENT_SECRET, token=session['id_token'])
        else:
            response = APIFactory.get("user/info", auth=APIFactory.Oauth2Auth(session['access_token']))
            if response is None:
                return None
            if self.handler.isError(response):
                self.handler.notifyNext(response.json(), flash)
                return None
            claims = response.json()

        return User(username=username, email=claims['email'], firstName=claims['firstName'],
                    lastName=claims['lastName'], address=claims['address'],
                    institution=claims['institution'], phone=claims['phone'])

    def login(self, session=None, user=None):
        '''
        Exchange the user's credentials for tokens and store them in ``session``.

        :param session: session mapping that receives ``access_token`` and ``id_token``
        :param user: user object carrying ``username`` and ``password``; the
            password is cleared once the request is made, even if it raises
        :return: boolean if login successful
        '''
        super(OAuth2Provider, self).login(session, user)

        params = {"username": user.username, "password": user.password,
                  "grant_type": "password", "client_id": config.OAUTH2_CLIENT_ID,
                  "client_secret": config.OAUTH2_CLIENT_SECRET, "scope": config.OPENID_CONNECT_SCOPES}
        try:
            # NOTE(review): if APIFactory.post maps ``params`` to the URL query
            # string, the password and client secret end up in server logs;
            # RFC 6749 expects them in a form-encoded request body. Confirm.
            response = APIFactory.post(path='oauth2/token', params=params)
        finally:
            # never keep the plaintext password longer than needed
            user.password = None

        if response is None:
            return False
        if self.handler.isError(response):
            self.handler.notifyNext(response.json(), flash)
            return False
        # log the status only: the body carries access/id tokens
        logging.getLogger(__name__).debug("token request returned HTTP %i", response.status_code)

        tokens = response.json()
        session['access_token'] = tokens['access_token']
        if "openid" in config.OPENID_CONNECT_SCOPES:
            session['id_token'] = tokens['id_token']
        else:
            # TODO: placeholder, not a real token. Nothing in this class reads
            # it when 'openid' is absent; confirm whether it is still needed.
            session['id_token'] = "some random string"
        return True

    def logout(self, session=None):
        '''
        Remove the OAuth tokens from ``session``.

        :return: True if an access token was removed, False if there was none (or no session)
        '''
        if not session or 'access_token' not in session:
            return False
        session.pop('access_token', None)
        session.pop('id_token', None)
        return True

    def is_authenticated(self):
        '''
        Not implemented: always returns None.
        TODO: check that id_token and access_token are not expired and hook
        this into the auth decorator.
        '''
        pass