From 320615f9ffd0c7a420fe75c3969267b9fdf56b08 Mon Sep 17 00:00:00 2001 From: Burathar Date: Sat, 20 Mar 2021 13:27:46 +0100 Subject: [PATCH] Remove numeric ID, make name identifier. Fix errors --- biscd/biscd/froms.py | 19 ++-- biscd/biscd/models/user.py | 17 ++-- biscd/biscd/models/yaml_serializable.py | 115 ++++++++++++------------ biscd/biscd/routes.py | 16 ++-- 4 files changed, 84 insertions(+), 83 deletions(-) diff --git a/biscd/biscd/froms.py b/biscd/biscd/froms.py index d0c464f..25cd804 100644 --- a/biscd/biscd/froms.py +++ b/biscd/biscd/froms.py @@ -2,7 +2,7 @@ from flask_wtf import FlaskForm from wtforms import SubmitField, StringField, PasswordField, BooleanField from wtforms.validators import DataRequired, Email, EqualTo, Length, ValidationError -from .models import User +from .models import User, Project class LoginForm(FlaskForm): username = StringField('Username', validators=[DataRequired()]) @@ -11,23 +11,28 @@ class LoginForm(FlaskForm): submit = SubmitField('Sign In') class RegistrationForm(FlaskForm): - username = StringField('Username', validators=[DataRequired()]) + username = StringField('Username', validators=[DataRequired(), Length(min=4, max=64)]) email = StringField('Email', validators=[DataRequired(), Email()]) - password = PasswordField('Password', validators=[DataRequired()]) + password = PasswordField('Password', validators=[DataRequired(), Length(min=10, max=128)]) password2 = PasswordField( 'Repeat Password', validators=[DataRequired(), EqualTo('password')]) submit = SubmitField('Register') def validate_username(self, username): - user = User(username.data) - if user is not None: + user = User.get(name=username.data) + if not any(user): raise ValidationError('Please use a different username.') def validate_email(self, email): - user = User.query.filter_by(email=email.data).first() - if user is not None: + user = User.get(email=email.data) + if not any(user): raise ValidationError('Please use a different email adress.') class NewProjectForm(FlaskForm): projectname = StringField('Project Name', validators=[DataRequired()]) submit = SubmitField('Add Project') + + def validate_projectname(self, projectname): + project = Project.get(name=projectname.data) + if not any(project): + raise ValidationError('Please use a different projectname.') diff --git a/biscd/biscd/models/user.py b/biscd/biscd/models/user.py index 53ed511..aa751d6 100644 --- a/biscd/biscd/models/user.py +++ b/biscd/biscd/models/user.py @@ -16,18 +16,13 @@ class User(YamlSerializable): def _yaml_object_name(self): return 'users' - def __init__(self, id=None, name=None, email=None, password=None, password_hash=None): - super().__init__(id) + def __init__(self, name=None, email=None): self.name = name - self.password_hash = set_password(password, password_hash) self.email = email + self.password_hash = None - def set_password(password, password_hash): - if password_hash: - return password_hash - if password: - return generate_password_hash(password) - return None + def set_password(self, password): + self.password_hash = generate_password_hash(password) def check_password(self, password): if not password: @@ -35,5 +30,5 @@ class User(YamlSerializable): return check_password_hash(self.password_hash, password) @login.user_loader -def load_user(id): - return super.get(int(id)) +def load_user(name): + return super.get(name) diff --git a/biscd/biscd/models/yaml_serializable.py b/biscd/biscd/models/yaml_serializable.py index f7552c9..5424e16 100644 --- a/biscd/biscd/models/yaml_serializable.py +++ b/biscd/biscd/models/yaml_serializable.py @@ -1,4 +1,5 @@ from abc import ABCMeta, abstractmethod +from flask import abort import yaml class MyMeta(metaclass=ABCMeta): @@ -15,24 +16,6 @@ class MyMeta(metaclass=ABCMeta): class YamlSerializable(object): __metaclass__ = MyMeta required_attributes = ['name'] - _id_counter = 0 - - @classmethod - def initialize(cls): - ymlserializables = cls._get_all_from_file() - cls._id_counter = max(ymlserializable.id for ymlserializable in ymlserializables) + 1 - - @abstractmethod - def __init__(self, id = None): - self.id = self.set_id(id) - - @classmethod - def set_id(cls, id): - if id is not None: - return id - id = cls._id_counter - cls._id_counter += 1 - return id @classmethod @property @@ -53,37 +36,58 @@ class YamlSerializable(object): ymlserializable_dict.pop('name') return {self.name: ymlserializable_dict} - def save(self): + def save(self, overwrite=True): # pylint: disable=no-member if self.name is None: raise TypeError("Name cannot be None") - ymlserializables = self._get_all_from_file() - if self.name in ([*ymlserializable][0] for ymlserializable in ymlserializables): - ymlserializables[self.name] = self.config_dict + ymlsls = self._get_all_from_file() + if self.name in ([*ymlsl][0] for ymlsl in ymlsls): + if overwrite: + ymlsls[self.name] = self.config_dict + else: + raise ValueError(f"A {type(self).__name__} with name {self.name} already exists!") + else: + ymlsls.append(self.config_dict) + print(ymlsls) + self._save_all_to_file(ymlsls) + + @classmethod + def first_or_404(cls, **kwargs): + ymlsl = next(cls.get(**kwargs), None) + if ymlsl is None: + abort(404) else: - ymlserializables.append(self.config_dict) - print(ymlserializables) - self._save_all_to_file(ymlserializables) + return ymlsl @classmethod - def get(cls, identifier): - if isinstance(identifier, int): - id = identifier - ymlserializable_dict = next( - ymlserializable for ymlserializable in - cls._get_all_from_file() if int(ymlserializable['id']) == id - ) - return cls._ymlserializable_from_dict(ymlserializable_dict) - - if isinstance(identifier, str): - name = identifier - ymlserializable_dict = next( - ymlserializable for ymlserializable in - cls._get_all_from_file() if [*ymlserializable][0] == name - ) - return cls._ymlserializable_from_dict(ymlserializable_dict) - - return None + def get(cls, **kwargs): + """Returns any matching instances + + Filters all saved instances by specified properties. + All remaining items are returned as a list. + """ + + if not any(kwargs): + return [] + ymlsl_dicts = cls._get_all_from_file() + for key, value in kwargs.items(): + # 'name' has to be evaluated separately; 'name' is the key of the entire object + if key == 'name': + ymlsl_dicts = (ymlsl_dict for ymlsl_dict in ymlsl_dicts + if [*ymlsl_dict][0] == value) + + # For other keys, filter out any item that does not contain a key, + # or that not match the key's value + ymlsl_dicts = (ymlsl_dict for ymlsl_dict in ymlsl_dicts if ymlsl_dict.key == value) + + # After each iteration: if no item is left, return None + if not any(ymlsl_dicts): + return [] + + ymlsls = [] + for ymlsl_dict in ymlsl_dicts: + ymlsls.append(cls._ymlserializable_from_dict(ymlsl_dict)) + return ymlsls @classmethod def _ymlserializable_from_dict(cls, ymldict): @@ -100,33 +104,30 @@ class YamlSerializable(object): ymldict['name'] = ymlserializable_name # Create empty instance - ymlserializable = cls() + ymlsl = cls() # Fill instance with dict - ymlserializable.__dict__ = ymldict - return ymlserializable + ymlsl.__dict__ = ymldict + return ymlsl @classmethod def list(cls): - ymlserializables = cls._get_all_from_file() + ymlsls = cls._get_all_from_file() ymlserializables_list = [] - for ymlserializable in ymlserializables: - name = [*ymlserializable][0] + for ymlsl in ymlsls: + name = [*ymlsl][0] ymlserializables_list.append(name) return ymlserializables_list @classmethod def _get_all_from_file(cls): with open(cls._storage_file) as file: - ymlserializables = yaml.load(file, yaml.FullLoader).get(cls._yaml_object_name) + ymlsls = yaml.load(file, yaml.FullLoader).get(cls._yaml_object_name) - highest_id = max(ymlserializable.id for ymlserializable in ymlserializables) + 1 - if highest_id > cls._id_counter: cls._id_counter = highest_id - - return ymlserializables + return ymlsls @classmethod - def _save_all_to_file(cls, ymlserializables): - ymlserializables_object = {cls._yaml_object_name : ymlserializables} + def _save_all_to_file(cls, ymlsls): + ymlserializables_object = {cls._yaml_object_name : ymlsls} with open(cls._storage_file, 'w') as file: - yaml.dump(ymlserializables_object, file) + yaml.dump(ymlserializables_object, file) \ No newline at end of file diff --git a/biscd/biscd/routes.py b/biscd/biscd/routes.py index 88cc31b..4eab840 100644 --- a/biscd/biscd/routes.py +++ b/biscd/biscd/routes.py @@ -1,8 +1,9 @@ -from flask import render_template, flash, abort, redirect, request, url_for, url_parse +from flask import render_template, flash, abort, redirect, request, url_for from flask_login import current_user, login_user, logout_user, login_required +from werkzeug.urls import url_parse from biscd import app from .models import Project, User -from .froms import NewProjectForm, LoginForm +from .froms import NewProjectForm, LoginForm, RegistrationForm @app.route('/', methods=['GET', 'POST']) @app.route('/index', methods=['GET', 'POST']) @@ -10,7 +11,7 @@ from .froms import NewProjectForm, LoginForm def index(): form = NewProjectForm() if form.validate_on_submit(): - project = Project(form.projectname.data) + project = Project.first_or_404(name=form.projectname.data) project.save() flash('You added a project!') project_names = Project.list() @@ -22,7 +23,7 @@ def login(): return redirect(url_for('index')) form = LoginForm() if form.validate_on_submit(): - user = User.get(form.username.data) + user = User.first_or_404(name=form.username.data) if user is None or not user.check_password(form.password.data): flash('Invalid username or password') return redirect(url_for('login')) @@ -44,10 +45,9 @@ def register(): return redirect(url_for('index')) form = RegistrationForm() if form.validate_on_submit(): - user = User(username=form.username.data, email=form.email.data) + user = User(name=form.username.data, email=form.email.data) user.set_password(form.password.data) - db.session.add(user) - db.session.commit() + user.save() flash('Congratulations, you are now a registered user!') return redirect(url_for('login')) return render_template('register.html', title='Register', form=form) @@ -57,7 +57,7 @@ def register(): @login_required def project_dashboard(project_name): print(project_name) - project = Project.get(project_name) + project = Project.get(name=project_name) if project is None: abort(404) return render_template('project.html', project=project)