From 5e638806b5fd897a4939fd72d6c913eecc6b0ac3 Mon Sep 17 00:00:00 2001 From: Eriks K Date: Wed, 3 Feb 2021 00:47:52 +0200 Subject: [PATCH] Switch from manual log creation to Python logging --- erepublik/citizen.py | 65 ++++++++-- erepublik/classes.py | 37 +----- erepublik/logging.py | 157 ++++++++++++++++++++++++ erepublik/utils.py | 279 +++++++++++++------------------------------ 4 files changed, 299 insertions(+), 239 deletions(-) create mode 100644 erepublik/logging.py diff --git a/erepublik/citizen.py b/erepublik/citizen.py index fc9ee7e..a8e1a5c 100644 --- a/erepublik/citizen.py +++ b/erepublik/citizen.py @@ -1,5 +1,5 @@ +import logging import re -import sys import warnings import weakref from datetime import datetime, time, timedelta @@ -13,6 +13,7 @@ from requests import HTTPError, RequestException, Response from . import access_points, classes, constants, types, utils from .classes import OfferItem +from .logging import ErepublikLogConsoleHandler, ErepublikFormatter, ErepublikFileHandler, ErepublikErrorHTTTPHandler class BaseCitizen(access_points.CitizenAPI): @@ -44,6 +45,8 @@ class BaseCitizen(access_points.CitizenAPI): stop_threads: Event = None telegram: classes.TelegramReporter = None + logger: logging.Logger + r: Response = None name: str = 'Not logged in!' logged_in: bool = False @@ -58,6 +61,9 @@ class BaseCitizen(access_points.CitizenAPI): self.my_companies = classes.MyCompanies(self) self.reporter = classes.Reporter(self) self.stop_threads = Event() + logger_class = logging.getLoggerClass() + self.logger = logger_class('Citizen') + self.telegram = classes.TelegramReporter(stop_event=self.stop_threads) self.config.email = email @@ -465,17 +471,14 @@ class BaseCitizen(access_points.CitizenAPI): self._inventory.offers = offers self.food['total'] = sum([self.food[q] * constants.FOOD_ENERGY[q] for q in constants.FOOD_ENERGY]) - def write_log(self, *args, **kwargs): - if self.config.interactive: - utils.write_interactive_log(*args, **kwargs) - else: - utils.write_silent_log(*args, **kwargs) + def write_log(self, msg: str): + self.logger.info(msg) def report_error(self, msg: str = "", is_warning: bool = False): if is_warning: - utils.process_warning(msg, self.name, sys.exc_info(), self) + self.logger.warning(msg) else: - utils.process_error(msg, self.name, sys.exc_info(), self, None, None) + self.logger.error(msg) def sleep(self, seconds: Union[int, float, Decimal]): if seconds < 0: @@ -485,9 +488,40 @@ class BaseCitizen(access_points.CitizenAPI): else: sleep(seconds) - def set_debug(self, debug: bool): - self.debug = bool(debug) - self._req.debug = bool(debug) + def init_logger(self): + for handler in list(self.logger.handlers): + self.logger.removeHandler(handler) + formatter = ErepublikFormatter() + if self.config.interactive: + console_handler = ErepublikLogConsoleHandler() + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) + file_handler = ErepublikFileHandler() + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + error_handler = ErepublikErrorHTTTPHandler(self.reporter) + error_handler.setFormatter(formatter) + self.logger.addHandler(error_handler) + self.logger.setLevel(logging.INFO) + + def set_debug(self, enable: bool): + self.debug = bool(enable) + self._req.debug = bool(enable) + self.logger.setLevel(logging.DEBUG if enable else logging.INFO) + + for handler in self.logger.handlers: + if isinstance(handler, (ErepublikLogConsoleHandler, ErepublikFileHandler)): + handler.setLevel(logging.DEBUG if enable else logging.INFO) + self.logger.debug(f"Debug messages {'enabled' if enable else 'disabled'}!") + + def set_interactive(self, enable: bool): + for handler in self.logger.handlers: + if isinstance(handler, (ErepublikLogConsoleHandler,)): + self.logger.removeHandler(handler) + if enable: + handler = ErepublikLogConsoleHandler() + handler.setLevel(logging.DEBUG if self.debug else logging.INFO) + self.logger.addHandler(handler) def to_json(self, indent: bool = False) -> str: return utils.json.dumps(self, cls=classes.ErepublikJSONEncoder, indent=4 if indent else None, sort_keys=True) @@ -517,6 +551,7 @@ class BaseCitizen(access_points.CitizenAPI): for k, v in data.get('config', {}).items(): if hasattr(player.config, k): setattr(player.config, k, v) + player.init_logger() player._resume_session() return player @@ -736,8 +771,7 @@ class BaseCitizen(access_points.CitizenAPI): self.r = r if r.url == f"{self.url}/login": - self.write_log("Citizen email and/or password is incorrect!") - raise KeyboardInterrupt + self.logger.error("Citizen email and/or password is incorrect!") else: re_name_id = re.search(r'', r.text) @@ -2635,6 +2669,11 @@ class _Citizen(CitizenAnniversary, CitizenCompanies, CitizenLeaderBoard, self.telegram.send_message(f"*Started* {utils.now():%F %T}") self.update_all(True) + for handler in self.logger.handlers: + if isinstance(handler, ErepublikErrorHTTTPHandler): + self.logger.removeHandler(handler) + break + self.logger.addHandler(ErepublikErrorHTTTPHandler(self.reporter)) def update_citizen_info(self, html: str = None): """ diff --git a/erepublik/classes.py b/erepublik/classes.py index 1147d89..58b1b1c 100644 --- a/erepublik/classes.py +++ b/erepublik/classes.py @@ -10,8 +10,8 @@ from requests import Response, Session, post from . import constants, types, utils __all__ = ['Battle', 'BattleDivision', 'BattleSide', 'Company', 'Config', 'Details', 'Energy', 'ErepublikException', - 'ErepublikJSONEncoder', 'ErepublikNetworkException', 'EnergyToFight', 'Holding', 'Inventory', 'MyCompanies', - 'OfferItem', 'Politics', 'Reporter', 'TelegramReporter', ] + 'ErepublikNetworkException', 'EnergyToFight', 'Holding', 'Inventory', 'MyCompanies', 'OfferItem', 'Politics', + 'Reporter', 'TelegramReporter', ] class ErepublikException(Exception): @@ -482,6 +482,8 @@ class Details: def __init__(self): self.next_pp = [] self.mayhem_skills = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0} + _default_country = constants.Country(0, 'Unknown', 'Unknown', 'XX') + self.citizenship = self.current_country = self.residence_country = _default_country @property def xp_till_level_up(self): @@ -603,10 +605,10 @@ class Reporter: if self.__to_update: for unreported_data in self.__to_update: unreported_data.update(player_id=self.citizen_id, key=self.key) - unreported_data = utils.json.loads(utils.json.dumps(unreported_data, cls=ErepublikJSONEncoder)) + unreported_data = utils.json_loads(utils.json_dumps(unreported_data)) self._req.post(f"{self.url}/bot/update", json=unreported_data) self.__to_update.clear() - data = utils.json.loads(utils.json.dumps(data, cls=ErepublikJSONEncoder)) + data = utils.json.loads(utils.json_dumps(data)) r = self._req.post(f"{self.url}/bot/update", json=data) return r @@ -687,33 +689,6 @@ class Reporter: return [] -class ErepublikJSONEncoder(utils.json.JSONEncoder): - def default(self, o): - from erepublik.citizen import Citizen - if isinstance(o, Decimal): - return float(f"{o:.02f}") - elif isinstance(o, datetime.datetime): - return dict(__type__='datetime', date=o.strftime("%Y-%m-%d"), time=o.strftime("%H:%M:%S"), - tzinfo=str(o.tzinfo) if o.tzinfo else None) - elif isinstance(o, datetime.date): - return dict(__type__='date', date=o.strftime("%Y-%m-%d")) - elif isinstance(o, datetime.timedelta): - return dict(__type__='timedelta', days=o.days, seconds=o.seconds, - microseconds=o.microseconds, total_seconds=o.total_seconds()) - elif isinstance(o, Response): - return dict(headers=dict(o.__dict__['headers']), url=o.url, text=o.text, status_code=o.status_code) - elif hasattr(o, 'as_dict'): - return o.as_dict - elif isinstance(o, set): - return list(o) - elif isinstance(o, Citizen): - return o.to_json() - try: - return super().default(o) - except Exception as e: # noqa - return 'Object is not JSON serializable' - - class BattleSide: points: int deployed: List[constants.Country] diff --git a/erepublik/logging.py b/erepublik/logging.py new file mode 100644 index 0000000..87d67ba --- /dev/null +++ b/erepublik/logging.py @@ -0,0 +1,157 @@ +import base64 +import datetime +import inspect +import logging +import os +import sys +import weakref +from pathlib import Path + +import requests +from logging import handlers, LogRecord +from typing import Union, Dict, Any + +from erepublik.classes import Reporter +from erepublik.constants import erep_tz +from erepublik.utils import slugify, json_loads, json, now, json_dumps + + +class ErepublikFileHandler(handlers.TimedRotatingFileHandler): + _file_path: Path + + def __init__(self, filename: str = 'log/erepublik.log', *args, **kwargs): + log_path = Path(filename) + self._file_path = log_path + log_path.parent.mkdir(parents=True, exist_ok=True) + at_time = erep_tz.localize(datetime.datetime.now()).replace(hour=0, minute=0, second=0, microsecond=0) + kwargs.update(atTime=at_time) + super().__init__(filename, *args, **kwargs) + + def doRollover(self) -> None: + self._file_path.parent.mkdir(parents=True, exist_ok=True) + super().doRollover() + + def emit(self, record: LogRecord) -> None: + self._file_path.parent.mkdir(parents=True, exist_ok=True) + super().emit(record) + + +class ErepublikLogConsoleHandler(logging.StreamHandler): + def __init__(self, *_): + super().__init__(sys.stdout) + + +class ErepublikFormatter(logging.Formatter): + """override logging.Formatter to use an aware datetime object""" + + dbg_fmt = "[%(asctime)s] DEBUG: %(module)s: %(lineno)d: %(msg)s" + info_fmt = "[%(asctime)s] %(msg)s" + default_fmt = "[%(asctime)s] %(levelname)s: %(msg)s" + + def converter(self, timestamp: Union[int, float]) -> datetime.datetime: + return datetime.datetime.utcfromtimestamp(timestamp).astimezone(erep_tz) + + def format(self, record: logging.LogRecord) -> str: + if record.levelno == logging.DEBUG: + self._fmt = self.dbg_fmt + elif record.levelno == logging.INFO: + self._fmt = self.info_fmt + else: + self._fmt = self.default_fmt + self._style = logging.PercentStyle(self._fmt) + return super().format(record) + + def formatTime(self, record, datefmt=None): + dt = self.converter(record.created) + if datefmt: + s = dt.strftime(datefmt) + else: + s = dt.strftime('%Y-%m-%d %H:%M:%S') + return s + + +class ErepublikErrorHTTTPHandler(handlers.HTTPHandler): + def __init__(self, reporter: Reporter): + logging.Handler.__init__(self, level=logging.ERROR) + self._reporter = weakref.ref(reporter) + self.host = 'localhost:5000' + self.url = '/ebot/error/' + self.method = 'POST' + self.secure = False + self.credentials = (str(reporter.citizen_id), reporter.key) + self.context = None + + @property + def reporter(self): + return self._reporter() + + def mapLogRecord(self, record: logging.LogRecord) -> Dict[str, Any]: + data = super().mapLogRecord(record) + + # Log last response + response = self.reporter.citizen.r + url = response.url + last_index = url.index("?") if "?" in url else len(response.url) + + name = slugify(response.url[len(self.reporter.citizen.url):last_index]) + html = response.text + + try: + json_loads(html) + ext = 'json' + except json.decoder.JSONDecodeError: + ext = 'html' + try: + resp_time = datetime.datetime.strptime( + response.headers.get('date'), '%a, %d %b %Y %H:%M:%S %Z' + ).replace(tzinfo=datetime.timezone.utc).astimezone(erep_tz).strftime('%F_%H-%M-%S') + except: + resp_time = slugify(response.headers.get('date')) + + resp = dict(name=f"{resp_time}_{name}.{ext}", content=html.encode('utf-8'), + mimetype="application/json" if ext == 'json' else "text/html") + + files = [('file', (resp.get('name'), resp.get('content'), resp.get('mimetype'))), ] + filename = f'log/{now().strftime("%F")}.log' + if os.path.isfile(filename): + files.append(('file', (filename[4:], open(filename, 'rb'), 'text/plain'))) + trace = inspect.trace() + local_vars = {} + if trace: + local_vars = trace[-1][0].f_locals + if local_vars.get('__name__') == '__main__': + local_vars.update(commit_id=local_vars.get('COMMIT_ID'), interactive=local_vars.get('INTERACTIVE'), + version=local_vars.get('__version__'), config=local_vars.get('CONFIG')) + + if local_vars: + if 'state_thread' in local_vars: + local_vars.pop('state_thread', None) + + if isinstance(local_vars.get('self'), self.reporter.citizen.__class__): + local_vars['self'] = repr(local_vars['self']) + if isinstance(local_vars.get('player'), self.reporter.citizen.__class__): + local_vars['player'] = repr(local_vars['player']) + if isinstance(local_vars.get('citizen'), self.reporter.citizen.__class__): + local_vars['citizen'] = repr(local_vars['citizen']) + + files.append(('file', ('local_vars.json', json_dumps(local_vars), "application/json"))) + files.append(('file', ('instance.json', self.reporter.citizen.to_json(indent=True), "application/json"))) + data.update(files=files) + return data + + def emit(self, record): + """ + Emit a record. + + Send the record to the Web server as a percent-encoded dictionary + """ + try: + proto = 'https' if self.secure else 'http' + u, p = self.credentials + s = 'Basic ' + base64.b64encode(f'{u}:{p}'.encode('utf-8')).strip().decode('ascii') + headers = {'Authorization': s} + data = self.mapLogRecord(record) + files = data.pop('files') if 'files' in data else None + requests.post(f"{proto}://{self.host}{self.url}", headers=headers, data=data, files=files) + except Exception: + self.handleError(record) diff --git a/erepublik/utils.py b/erepublik/utils.py index af28eb9..0561149 100644 --- a/erepublik/utils.py +++ b/erepublik/utils.py @@ -1,10 +1,8 @@ import datetime -import inspect import os import re import sys import time -import traceback import unicodedata import warnings from base64 import b64encode @@ -14,6 +12,7 @@ from typing import Any, Dict, List, Union import pytz import requests +from requests import Response from . import __version__, constants @@ -22,12 +21,12 @@ try: except ImportError: import json -__all__ = ['VERSION', 'calculate_hit', 'caught_error', 'date_from_eday', 'eday_from_date', 'deprecation', - 'get_air_hit_dmg_value', 'get_file', 'get_ground_hit_dmg_value', 'get_sleep_seconds', 'good_timedelta', - 'interactive_sleep', 'json', 'localize_dt', 'localize_timestamp', 'normalize_html_json', 'now', - 'process_error', 'process_warning', 'send_email', 'silent_sleep', 'slugify', 'write_file', 'write_request', - 'write_interactive_log', 'write_silent_log', 'get_final_hit_dmg', 'wait_for_lock', - 'json_decode_object_hook', 'json_load', 'json_loads'] +__all__ = [ + 'VERSION', 'calculate_hit', 'date_from_eday', 'eday_from_date', 'deprecation', 'get_final_hit_dmg', 'write_file', + 'get_air_hit_dmg_value', 'get_file', 'get_ground_hit_dmg_value', 'get_sleep_seconds', 'good_timedelta', 'slugify', + 'interactive_sleep', 'json', 'localize_dt', 'localize_timestamp', 'normalize_html_json', 'now', 'silent_sleep', + 'json_decode_object_hook', 'json_load', 'json_loads', 'json_dump', 'json_dumps', 'b64json', 'ErepublikJSONEncoder', +] VERSION: str = __version__ @@ -103,25 +102,25 @@ def interactive_sleep(sleep_seconds: int): silent_sleep = time.sleep -def _write_log(msg, timestamp: bool = True, should_print: bool = False): - erep_time_now = now() - txt = f"[{erep_time_now.strftime('%F %T')}] {msg}" if timestamp else msg - if not os.path.isdir('log'): - os.mkdir('log') - with open(f'log/{erep_time_now.strftime("%F")}.log', 'a', encoding='utf-8') as f: - f.write(f'{txt}\n') - if should_print: - print(txt) - - -def write_interactive_log(*args, **kwargs): - kwargs.pop('should_print', None) - _write_log(should_print=True, *args, **kwargs) - - -def write_silent_log(*args, **kwargs): - kwargs.pop('should_print', None) - _write_log(should_print=False, *args, **kwargs) +# def _write_log(msg, timestamp: bool = True, should_print: bool = False): +# erep_time_now = now() +# txt = f"[{erep_time_now.strftime('%F %T')}] {msg}" if timestamp else msg +# if not os.path.isdir('log'): +# os.mkdir('log') +# with open(f'log/{erep_time_now.strftime("%F")}.log', 'a', encoding='utf-8') as f: +# f.write(f'{txt}\n') +# if should_print: +# print(txt) +# +# +# def write_interactive_log(*args, **kwargs): +# kwargs.pop('should_print', None) +# _write_log(should_print=True, *args, **kwargs) +# +# +# def write_silent_log(*args, **kwargs): +# kwargs.pop('should_print', None) +# _write_log(should_print=False, *args, **kwargs) def get_file(filepath: str) -> str: @@ -155,89 +154,6 @@ def write_file(filename: str, content: str) -> int: return ret -def write_request(response: requests.Response, is_error: bool = False): - from erepublik import Citizen - - # Remove GET args from url name - url = response.url - last_index = url.index("?") if "?" in url else len(response.url) - - name = slugify(response.url[len(Citizen.url):last_index]) - html = response.text - - try: - json.loads(html) - ext = 'json' - except json.decoder.JSONDecodeError: - ext = 'html' - - if not is_error: - filename = f"debug/requests/{now().strftime('%F_%H-%M-%S')}_{name}.{ext}" - write_file(filename, html) - else: - return dict(name=f"{now().strftime('%F_%H-%M-%S')}_{name}.{ext}", content=html.encode('utf-8'), - mimetype="application/json" if ext == 'json' else "text/html") - - -def send_email(name: str, content: List[Any], player=None, local_vars: Dict[str, Any] = None, - promo: bool = False, captcha: bool = False): - if local_vars is None: - local_vars = {} - from erepublik import Citizen - - file_content_template = '{title}{body}' - if isinstance(player, Citizen) and player.r: - resp = write_request(player.r, is_error=True) - else: - resp = dict(name='None.html', mimetype='text/html', - content=file_content_template.format(body='
'.join(content), title='Error')) - - if promo: - resp = dict(name=f"{name}.html", mimetype='text/html', - content=file_content_template.format(title='Promo', body='
'.join(content))) - subject = f"[eBot][{now().strftime('%F %T')}] Promos: {name}" - - elif captcha: - resp = dict(name=f'{name}.html', mimetype='text/html', - content=file_content_template.format(title='ReCaptcha', body='
'.join(content))) - subject = f"[eBot][{now().strftime('%F %T')}] RECAPTCHA: {name}" - else: - subject = f"[eBot][{now().strftime('%F %T')}] Bug trace: {name}" - - body = "".join(traceback.format_stack()) + \ - "\n\n" + \ - "\n".join(content) - data = dict(send_mail=True, subject=subject, bugtrace=body) - if promo: - data.update(promo=True) - elif captcha: - data.update(captcha=True) - else: - data.update(bug=True) - - files = [('file', (resp.get('name'), resp.get('content'), resp.get('mimetype'))), ] - filename = f'log/{now().strftime("%F")}.log' - if os.path.isfile(filename): - files.append(('file', (filename[4:], open(filename, 'rb'), 'text/plain'))) - if local_vars: - if 'state_thread' in local_vars: - local_vars.pop('state_thread', None) - - if isinstance(local_vars.get('self'), Citizen): - local_vars['self'] = repr(local_vars['self']) - if isinstance(local_vars.get('player'), Citizen): - local_vars['player'] = repr(local_vars['player']) - if isinstance(local_vars.get('citizen'), Citizen): - local_vars['citizen'] = repr(local_vars['citizen']) - - from erepublik.classes import ErepublikJSONEncoder - files.append(('file', ('local_vars.json', json.dumps(local_vars, cls=ErepublikJSONEncoder), - "application/json"))) - if isinstance(player, Citizen): - files.append(('file', ('instance.json', player.to_json(indent=True), "application/json"))) - requests.post('https://pasts.72.lv', data=data, files=files) - - def normalize_html_json(js: str) -> str: js = re.sub(r' \'(.*?)\'', lambda a: f'"{a.group(1)}"', js) js = re.sub(r'(\d\d):(\d\d):(\d\d)', r'\1\2\3', js) @@ -246,72 +162,6 @@ def normalize_html_json(js: str) -> str: return js -def caught_error(e: Exception): - process_error(str(e), 'Unclassified', sys.exc_info(), interactive=False) - - -def process_error(log_info: str, name: str, exc_info: tuple, citizen=None, commit_id: str = None, - interactive: bool = None): - """ - Process error logging and email sending to developer - :param interactive: Should print interactively - :type interactive: bool - :param log_info: String to be written in output - :type log_info: str - :param name: String Instance name - :type name: str - :param exc_info: tuple output from sys.exc_info() - :type exc_info: tuple - :param citizen: Citizen instance - :type citizen: Citizen - :param commit_id: Caller's code version's commit id - :type commit_id: str - """ - type_, value_, traceback_ = exc_info - content = [log_info] - content += [f"eRepublik version {VERSION}"] - if commit_id: - content += [f"Commit id {commit_id}"] - content += [str(value_), str(type_), ''.join(traceback.format_tb(traceback_))] - - if interactive: - write_interactive_log(log_info) - elif interactive is not None: - write_silent_log(log_info) - trace = inspect.trace() - if trace: - local_vars = trace[-1][0].f_locals - if local_vars.get('__name__') == '__main__': - local_vars.update(commit_id=local_vars.get('COMMIT_ID'), interactive=local_vars.get('INTERACTIVE'), - version=local_vars.get('__version__'), config=local_vars.get('CONFIG')) - else: - local_vars = dict() - send_email(name, content, citizen, local_vars=local_vars) - - -def process_warning(log_info: str, name: str, exc_info: tuple, citizen=None, commit_id: str = None): - """ - Process error logging and email sending to developer - :param log_info: String to be written in output - :param name: String Instance name - :param exc_info: tuple output from sys.exc_info() - :param citizen: Citizen instance - :param commit_id: Code's version by commit id - """ - type_, value_, traceback_ = exc_info - content = [log_info] - if commit_id: - content += [f'Commit id: {commit_id}'] - content += [str(value_), str(type_), ''.join(traceback.format_tb(traceback_))] - - trace = inspect.trace() - if trace: - local_vars = trace[-1][0].f_locals - else: - local_vars = dict() - send_email(name, content, citizen, local_vars=local_vars) - - def slugify(value, allow_unicode=False) -> str: """ Function copied from Django2.2.1 django.utils.text.slugify @@ -378,25 +228,25 @@ def deprecation(message): warnings.warn(message, DeprecationWarning, stacklevel=2) -def wait_for_lock(function): - def wrapper(instance, *args, **kwargs): - if not instance.concurrency_available.wait(600): - e = 'Concurrency not freed in 10min!' - instance.write_log(e) - if instance.debug: - instance.report_error(e) - return None - else: - instance.concurrency_available.clear() - try: - ret = function(instance, *args, **kwargs) - except Exception as e: - instance.concurrency_available.set() - raise e - instance.concurrency_available.set() - return ret - - return wrapper +# def wait_for_lock(function): +# def wrapper(instance, *args, **kwargs): +# if not instance.concurrency_available.wait(600): +# e = 'Concurrency not freed in 10min!' +# instance.write_log(e) +# if instance.debug: +# instance.report_error(e) +# return None +# else: +# instance.concurrency_available.clear() +# try: +# ret = function(instance, *args, **kwargs) +# except Exception as e: +# instance.concurrency_available.set() +# raise e +# instance.concurrency_available.set() +# return ret +# +# return wrapper def json_decode_object_hook( @@ -432,6 +282,18 @@ def json_loads(s: str, **kwargs): return json.loads(s, **kwargs) +def json_dump(obj, fp, *args, **kwargs): + if not kwargs.get('cls'): + kwargs.update(cls=ErepublikJSONEncoder) + return json.dump(obj, fp, *args, **kwargs) + + +def json_dumps(obj, *args, **kwargs): + if not kwargs.get('cls'): + kwargs.update(cls=ErepublikJSONEncoder) + return json.dumps(obj, *args, **kwargs) + + def b64json(obj: Union[Dict[str, Union[int, List[str]]], List[str]]): if isinstance(obj, list): return b64encode(json.dumps(obj).encode('utf-8')).decode('utf-8') @@ -444,3 +306,30 @@ def b64json(obj: Union[Dict[str, Union[int, List[str]]], List[str]]): from .classes import ErepublikException raise ErepublikException(f'Unhandled object type! obj is {type(obj)}') return b64encode(json.dumps(obj).encode('utf-8')).decode('utf-8') + + +class ErepublikJSONEncoder(json.JSONEncoder): + def default(self, o): + from erepublik.citizen import Citizen + if isinstance(o, Decimal): + return float(f"{o:.02f}") + elif isinstance(o, datetime.datetime): + return dict(__type__='datetime', date=o.strftime("%Y-%m-%d"), time=o.strftime("%H:%M:%S"), + tzinfo=str(o.tzinfo) if o.tzinfo else None) + elif isinstance(o, datetime.date): + return dict(__type__='date', date=o.strftime("%Y-%m-%d")) + elif isinstance(o, datetime.timedelta): + return dict(__type__='timedelta', days=o.days, seconds=o.seconds, + microseconds=o.microseconds, total_seconds=o.total_seconds()) + elif isinstance(o, Response): + return dict(headers=dict(o.__dict__['headers']), url=o.url, text=o.text, status_code=o.status_code) + elif hasattr(o, 'as_dict'): + return o.as_dict + elif isinstance(o, set): + return list(o) + elif isinstance(o, Citizen): + return o.to_json() + try: + return super().default(o) + except Exception as e: # noqa + return 'Object is not JSON serializable'