diff --git a/luncho/blueprints/users.py b/luncho/blueprints/users.py index aab6410..9b2b8c4 100644 --- a/luncho/blueprints/users.py +++ b/luncho/blueprints/users.py @@ -6,13 +6,13 @@ from flask import Blueprint from flask import request from flask import jsonify -# from flask import current_app -# from pony.orm import commit +from sqlalchemy.exc import IntegrityError from luncho.helpers import ForceJSON from luncho.server import User +from luncho.server import db users = Blueprint('users', __name__) @@ -23,14 +23,19 @@ def create_user(): """Create a new user. Request must be: { "username": "username", "full_name": "Full Name", "password": "hash" }""" json = request.get_json(force=True) - # new_user = User(username=json['username'], - # fullname=json['full_name'], - # passhash=json['password'], - # validated=False) - User(username=json['username'], - fullname=json['full_name'], - passhash=json['password'], - validated=False) - # commit() - - return jsonify(status='OK') + + try: + new_user = User(username=json['username'], + fullname=json['full_name'], + passhash=json['password'], + validated=False) + + db.session.add(new_user) + db.session.commit() + + return jsonify(status='OK') + except IntegrityError: + resp = jsonify(status='ERROR', + error='username already exists') + resp.status_code = 409 + return resp diff --git a/luncho/helpers.py b/luncho/helpers.py index 2b9cca9..adc3ee6 100644 --- a/luncho/helpers.py +++ b/luncho/helpers.py @@ -18,8 +18,10 @@ class ForceJSON(object): def check_json(*args, **kwargs): json = request.get_json(force=True, silent=True) if not json: - return jsonify(status='ERROR', - error='Request MUST be in JSON format'), 400 + resp = jsonify(status='ERROR', + error='Request MUST be in JSON format') + resp.status_code = 400 + return resp # now we have the JSON, let's check if all the fields are here. missing = [] @@ -30,8 +32,10 @@ class ForceJSON(object): if missing: fields = ', '.join(missing) error = 'Missing fields: {fields}'.format(fields=fields) - return jsonify(status='ERROR', + resp = jsonify(status='ERROR', error=error) + resp.status_code = 400 + return resp return func(*args, **kwargs) return check_json diff --git a/tests/users_tests.py b/tests/users_tests.py index 0873644..36396bc 100644 --- a/tests/users_tests.py +++ b/tests/users_tests.py @@ -15,21 +15,64 @@ class TestUsers(unittest.TestCase): server.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite://' server.app.config['TESTING'] = True - print server.app.config['SQLALCHEMY_DATABASE_URI'] self.app = server.app.test_client() + server.db.create_all() - # def tearDown(self): - # os.unlink(server.app.config['SQLITE_FILENAME']) + def tearDown(self): + server.db.drop_all(bind=None) def test_create_user(self): + """Simple user creation.""" request = {'username': 'username', 'full_name': 'full name', 'password': 'hash'} rv = self.app.put('/user/', data=json.dumps(request), content_type='application/json') + self.assertEqual(rv.status_code, 200) self.assertEqual(json.loads(rv.data), {'status': 'OK'}) + def test_duplicate_user(self): + """Check the status for trying to create a user that it is already + in the database.""" + self.test_create_user() # create the first user + + # now duplicate + request = {'username': 'username', + 'full_name': 'full name', + 'password': 'hash'} + rv = self.app.put('/user/', + data=json.dumps(request), + content_type='application/json') + + expected = {"status": "ERROR", + "error": "username already exists"} + + self.assertEqual(rv.status_code, 409) + self.assertEqual(json.loads(rv.data), expected) + + def test_no_json(self): + """Check the status when doing a request that it's not JSON.""" + rv = self.app.put('/user/', + data='', + content_type='text/html') + + expected = {"error": "Request MUST be in JSON format", + "status": "ERROR"} + self.assertEqual(rv.status_code, 400) + self.assertEqual(json.loads(rv.data), expected) + + def test_missing_fields(self): + request = {'password': 'hash'} + rv = self.app.put('/user/', + data=json.dumps(request), + content_type='application/json') + + resp = {'error': 'Missing fields: username, full_name', + 'status': 'ERROR'} + self.assertEqual(rv.status_code, 400) + self.assertEqual(json.loads(rv.data), resp) + if __name__ == '__main__': unittest.main()