Commit db775b24 authored by Sascha Herzinger

Merge branch 'new-state' into 'master'

New state

See merge request Fractalis/fractalis!1
parents 6524d30c c5b76848
@@ -4,7 +4,8 @@ import abc
import json
import re
import logging
from typing import List, Tuple
from uuid import UUID
from typing import List, Tuple, Union
from pandas import read_csv, DataFrame
from celery import Task
@@ -126,7 +127,7 @@ class AnalyticTask(Task, metaclass=abc.ABCMeta):
value.endswith('$')
@staticmethod
def parse_value(value: str) -> Tuple[str, dict]:
def parse_value(value: str) -> Tuple[Union[str, None], dict]:
"""Extract data task id and filters from the string.
:param value: A string that contains a data task id.
:return: A tuple of id and filters to apply later.
@@ -135,7 +136,7 @@ class AnalyticTask(Task, metaclass=abc.ABCMeta):
# noinspection PyBroadException
try:
value = json.loads(value)
data_task_id = value['id']
data_task_id = str(value['id'])
filters = value.get('filters')
except Exception:
logger.warning("Failed to parse value. "
@@ -143,6 +144,12 @@ class AnalyticTask(Task, metaclass=abc.ABCMeta):
"but nothing else.")
data_task_id = value
filters = None
# noinspection PyBroadException
try:
data_task_id = str(UUID(data_task_id))
except Exception:
logger.warning("'{}' is not a valid task id.".format(data_task_id))
data_task_id = None
return data_task_id, filters
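# Illustrative behaviour (hypothetical values): a JSON object such as
# '{"id": "<uuid>", "filters": {...}}' yields ('<uuid>', {...}); a bare
# string yields (value, None) if it is a valid UUID and (None, None)
# otherwise, with a warning logged.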
def prepare_args(self, session_data_tasks: List[str],
@@ -32,6 +32,8 @@ def create_data_task() -> Tuple[Response, int]:
server=payload['server'],
auth=payload['auth'])
task_ids = etl_handler.handle(descriptors=payload['descriptors'],
data_tasks=session['data_tasks'],
use_existing=False,
wait=wait)
session['data_tasks'] += task_ids
session['data_tasks'] = list(set(session['data_tasks']))
@@ -72,17 +74,21 @@ def get_all_data() -> Tuple[Response, int]:
logger.debug("Received GET request on /data.")
wait = request.args.get('wait') == '1'
data_states = []
existing_data_tasks = []
for task_id in session['data_tasks']:
data_state = get_data_state_for_task_id(task_id, wait)
if data_state is None:
warning = "Data state with task_id '{}' expired. " \
"Discarding...".format(task_id)
logger.warning(warning)
continue
# remove internal information from response
del data_state['file_path']
del data_state['meta']
# add additional information to response
data_states.append(data_state)
existing_data_tasks.append(task_id)
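# prune expired task ids from the session so they are not queried again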
session['data_tasks'] = existing_data_tasks
logger.debug("Data states collected. Sending response.")
return jsonify({'data_states': data_states}), 200
@@ -5,10 +5,10 @@ import abc
import json
import logging
from uuid import uuid4
from typing import List
from typing import List, Union
from fractalis import app, redis
import manage
from fractalis import app, redis, celery
from fractalis.data.etl import ETL
@@ -87,6 +87,7 @@ class ETLHandler(metaclass=abc.ABCMeta):
'file_path': file_path,
'label': self.make_label(descriptor),
'data_type': data_type,
'hash': self.descriptor_to_hash(descriptor),
'meta': {
'descriptor': descriptor,
}
@@ -95,11 +96,75 @@ class ETLHandler(metaclass=abc.ABCMeta):
value=json.dumps(data_state),
time=app.config['FRACTALIS_CACHE_EXP'])
def handle(self, descriptors: List[dict], wait: bool = False) -> List[str]:
def descriptor_to_hash(self, descriptor: dict) -> int:
"""Compute hash for the given descriptor. Used to identify duplicates.
:param descriptor: ETL descriptor to hash.
:return: Hash value unique to the server, handler, and descriptor.
"""
string = '{}-{}-{}'.format(self._server,
self._handler,
str(descriptor))
hash_value = int.from_bytes(string.encode('utf-8'), 'little')
return hash_value
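# Note: int.from_bytes yields an arbitrary-width integer encoding of the
# string rather than a fixed-size digest, e.g.
# int.from_bytes('ab'.encode('utf-8'), 'little') == 25185 (0x6261), so two
# descriptors only collide if their string representations are identical.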
def find_duplicates(self, data_tasks: List[str],
descriptor: dict) -> List[str]:
"""Search for duplicates of the given descriptor and return a list
of associated task ids.
:param data_tasks: Limit duplicate search to these task ids.
:param descriptor: ETL descriptor. Used to identify duplicates.
:return: The task ids of the duplicates found.
"""
task_ids = []
hash_value = self.descriptor_to_hash(descriptor)
for task_id in data_tasks:
value = redis.get('data:{}'.format(task_id))
if value is None:
continue
data_state = json.loads(value)
if hash_value == data_state['hash']:
task_ids.append(task_id)
return task_ids
def remove_duplicates(self, data_tasks: List[str],
descriptor: dict) -> None:
"""Delete the duplicates of the given descriptor from redis and call
the janitor afterwards to cleanup orphaned files.
:param data_tasks: Limit duplicate search to these task ids.
:param descriptor: ETL descriptor. Used to identify duplicates.
"""
task_ids = self.find_duplicates(data_tasks, descriptor)
for task_id in task_ids:
redis.delete('data:{}'.format(task_id))
manage.janitor.delay()
def find_duplicate_task_id(self, data_tasks: List[str],
descriptor: dict) -> Union[str, None]:
"""Search for duplicates of the given descriptor and return their
task id if the state is SUBMITTED or SUCCESS, meaning the data are
reusable.
:param data_tasks: Limit search to this list.
:param descriptor: ETL descriptor. Used to identify duplicates.
:return: Task id if a valid duplicate has been found, None otherwise.
"""
task_ids = self.find_duplicates(data_tasks, descriptor)
for task_id in task_ids:
async_result = celery.AsyncResult(task_id)
if (async_result.state == 'SUBMITTED' or
async_result.state == 'SUCCESS'):
return task_id
return None
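# Note: 'SUBMITTED' is presumably the project's custom Celery state for
# ETLs that are still loading; any other state (e.g. FAILURE, or the
# default PENDING reported for unknown task ids) makes a duplicate
# unusable.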
def handle(self, descriptors: List[dict], data_tasks: List[str],
use_existing: bool, wait: bool = False) -> List[str]:
"""Create instances of ETL for the given descriptors and submit them
(ETL implements celery.Task) to the broker. The task ids are returned
to keep track of them.
:param descriptors: A list of items describing the data to download.
:param data_tasks: Limit search for duplicates to this list.
:param use_existing: If a duplicate with state 'SUBMITTED' or 'SUCCESS'
already exists, use it instead of starting a new ETL. If this is False,
existing duplicates are deleted!
:param wait: Makes this method synchronous by waiting for the tasks to
return.
:return: The list of task ids for the submitted tasks.
@@ -107,6 +172,14 @@ class ETLHandler(metaclass=abc.ABCMeta):
data_dir = os.path.join(app.config['FRACTALIS_TMP_DIR'], 'data')
task_ids = []
for descriptor in descriptors:
if use_existing:
task_id = self.find_duplicate_task_id(data_tasks, descriptor)
if task_id:
task_ids.append(task_id)
data_tasks.append(task_id)
continue
else:
self.remove_duplicates(data_tasks, descriptor)
task_id = str(uuid4())
file_path = os.path.join(data_dir, task_id)
etl = ETL.factory(handler=self._handler, descriptor=descriptor)
@@ -118,9 +191,11 @@ class ETLHandler(metaclass=abc.ABCMeta):
async_result = etl.apply_async(kwargs=kwargs, task_id=task_id)
assert async_result.id == task_id
task_ids.append(task_id)
data_tasks.append(task_id)
if wait and async_result.state == 'SUBMITTED':
logger.debug("'wait' was set. Waiting for tasks to finish ...")
async_result.get(propagate=False)
task_ids = list(set(task_ids))
return task_ids
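# Usage sketch with hypothetical handler, server, and descriptor values:
#   etl_handler = ETLHandler.factory(handler='transmart',
#                                    server='https://tm.example.org',
#                                    auth={'token': '...'})
#   task_ids = etl_handler.handle(descriptors=[{'data_type': 'numerical'}],
#                                 data_tasks=session['data_tasks'],
#                                 use_existing=False, wait=True)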
@staticmethod
@@ -3,17 +3,18 @@
import re
import json
import logging
import ast
from uuid import UUID, uuid4
from typing import Tuple
from flask import Blueprint, jsonify, Response, request, session
from fractalis import redis
from fractalis import redis, celery
from fractalis.validator import validate_json, validate_schema
from fractalis.analytics.task import AnalyticTask
from fractalis.data.etlhandler import ETLHandler
from fractalis.data.controller import get_data_state_for_task_id
from fractalis.state.schema import request_state_access_schema
from fractalis.state.schema import request_state_access_schema, \
save_state_schema
state_blueprint = Blueprint('state_blueprint', __name__)
@@ -22,37 +23,43 @@ logger = logging.getLogger(__name__)
@state_blueprint.route('', methods=['POST'])
@validate_json
@validate_schema(save_state_schema)
def save_state() -> Tuple[Response, int]:
"""Save given payload to redis, so it can be accessed later on.
:return: UUID linked to the saved state.
"""
logger.debug("Received POST request on /state.")
payload = request.get_json(force=True)
# check if task ids in payload are valid
matches = re.findall('\$.+?\$', str(payload))
if not matches:
state = str(payload['state'])
matches = re.findall('\$.+?\$', state)
task_ids = [AnalyticTask.parse_value(match)[0] for match in matches]
task_ids = [task_id for task_id in set(task_ids) if task_id is not None]
if not task_ids:
error = "This state cannot be saved because it contains no data " \
"task ids. These are used to verify access to the state and " \
"its potentially sensitive data."
logger.error(error)
return jsonify({'error': error}), 400
for match in matches:
task_id, _ = AnalyticTask.parse_value(match)
descriptors = []
for task_id in task_ids:
value = redis.get('data:{}'.format(task_id))
if value is None:
error = "Data task id {} could not be found in redis. " \
"State cannot be saved.".format(task_id)
logger.error(error)
return jsonify({'error': error}), 400
try:
json.loads(value)['meta']['descriptor']
except (ValueError, KeyError):
error = "Task with id {} was found in redis but does not " \
"represent a valid data state. " \
"State cannot be saved.".format(task_id)
return jsonify({'error': error}), 400
data_state = json.loads(value)
descriptors.append(data_state['meta']['descriptor'])
assert len(task_ids) == len(descriptors)
meta_state = {
'state': ast.literal_eval(state),
'server': payload['server'],
'handler': payload['handler'],
'task_ids': task_ids,
'descriptors': descriptors
}
uuid = uuid4()
redis.set(name='state:{}'.format(uuid), value=json.dumps(payload))
redis.set(name='state:{}'.format(uuid), value=json.dumps(meta_state))
logger.debug("Successfully saved data to redis. Sending response.")
return jsonify({'state_id': uuid}), 201
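# Illustrative shape of the value stored under 'state:<uuid>':
#   {"state": {...}, "server": "https://...", "handler": "<handler name>",
#    "task_ids": ["<task id>", ...], "descriptors": [{...}, ...]}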
@@ -76,28 +83,18 @@ def request_state_access(state_id: UUID) -> Tuple[Response, int]:
error = "Could not find state associated with id {}".format(state_id)
logger.error(error)
return jsonify({'error': error}), 404
descriptors = []
matches = re.findall('\$.+?\$', str(json.loads(value)))
for match in matches:
task_id, _ = AnalyticTask.parse_value(match)
value = redis.get('data:{}'.format(task_id))
if value is None:
error = "The state with id {} exists, but one or more of the " \
"associated data task ids are missing. Hence this saved " \
"state is lost forever because access can no longer be " \
"verified. Deleting state..."
logger.error(error)
redis.delete('state:{}'.format(state_id))
return jsonify({'error': error}), 403
data_state = json.loads(value)
descriptors.append(data_state['meta']['descriptor'])
etl_handler = ETLHandler.factory(handler=payload['handler'],
server=payload['server'],
meta_state = json.loads(value)
etl_handler = ETLHandler.factory(handler=meta_state['handler'],
server=meta_state['server'],
auth=payload['auth'])
task_ids = etl_handler.handle(descriptors=descriptors, wait=wait)
task_ids = etl_handler.handle(descriptors=meta_state['descriptors'],
data_tasks=session['data_tasks'],
use_existing=True,
wait=wait)
session['data_tasks'] += task_ids
session['data_tasks'] = list(set(session['data_tasks']))
# if all task finish successfully we now that session has access to state
# if all tasks finish successfully we know that session has access to state
session['state_access'][state_id] = task_ids
logger.debug("Tasks successfully submitted. Sending response.")
return jsonify(''), 202
@@ -112,25 +109,31 @@ def get_state_data(state_id: UUID) -> Tuple[Response, int]:
:return: Previously saved state.
"""
logger.debug("Received GET request on /state/<uuid:state_id>.")
wait = request.args.get('wait') == '1'
state_id = str(state_id)
if state_id not in session['state_access']:
value = redis.get('state:{}'.format(state_id))
if not value or state_id not in session['state_access']:
error = "Cannot get state. Make sure to submit a POST request " \
"to this very same URL containing credentials and server " \
"data to launch access verification. Only after that will a " \
"GET request return the saved state, provided access was granted."
logger.error(error)
return jsonify({'error': error}), 404
meta_state = json.loads(value)
state = json.dumps(meta_state['state'])
for task_id in session['state_access'][state_id]:
data_state = get_data_state_for_task_id(task_id=task_id, wait=wait)
if data_state is not None and data_state['etl_state'] == 'SUBMITTED':
async_result = celery.AsyncResult(task_id)
if async_result.state == 'SUBMITTED':
return jsonify({'message': 'ETLs are still running.'}), 202
elif data_state is not None and data_state['etl_state'] == 'SUCCESS':
elif async_result.state == 'SUCCESS':
continue
else:
error = "One or more ETLs failed or have an unknown status. " \
"Assuming no access to saved state."
logger.error(error)
return jsonify({'error': error}), 403
state = json.loads(redis.get('state:{}'.format(state_id)))
return jsonify({'state': state}), 200
# replace task ids in state with the ids of the freshly loaded data
for i, task_id in enumerate(meta_state['task_ids']):
state = re.sub(pattern=task_id,
repl=session['state_access'][state_id][i],
string=state)
return jsonify({'state': json.loads(state)}), 200
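# e.g. if the saved state referenced old task id 'aaaa...' and the re-run
# ETL was submitted as 'bbbb...', every occurrence of 'aaaa...' in the
# serialized state is rewritten to 'bbbb...' (illustrative ids).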
request_state_access_schema = {
save_state_schema = {
"type": "object",
"properties": {
"state": {"type": "object"},
"handler": {"type": "string"},
"server": {"type": "string"},
"server": {"type": "string"}
},
"required": ["handler", "server", "state"]
}
request_state_access_schema = {
"type": "object",
"properties": {
"auth": {
"type": "object",
"properties": {
@@ -13,5 +22,5 @@ request_state_access_schema = {
"minProperties": 1
}
},
"required": ["handler", "server", "auth"]
"required": ["auth"]
}
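# A body accepted by save_state_schema (illustrative values):
#   {"state": {...}, "handler": "ada", "server": "https://..."}
# request_state_access_schema, by contrast, now only requires 'auth'.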
@@ -58,10 +58,12 @@ class TestData:
@pytest.fixture(scope='function', params=['small', 'big'])
def payload(self, request):
load = self.small_load() if request.param == 'small' \
else self.big_load()
return {'size': len(load['descriptors']),
'serialized': flask.json.dumps(load)}
def _payload():
load = self.small_load() if request.param == 'small' \
else self.big_load()
return {'size': len(load['descriptors']),
'serialized': flask.json.dumps(load)}
return _payload
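# The fixture now returns a factory rather than a dict: each payload() call
# serializes a fresh load, presumably so that repeated posts produce
# distinct descriptors which the new duplicate handling will not collapse.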
@pytest.fixture(scope='function', params=['small', 'big'])
def faiload(self, request):
@@ -144,15 +146,17 @@ class TestData:
assert bad_post().status_code == 400
def test_valid_response_on_post(self, test_client, payload):
rv = test_client.post('/data', data=payload['serialized'])
data = payload()
rv = test_client.post('/data', data=data['serialized'])
assert rv.status_code == 201
body = flask.json.loads(rv.get_data())
assert not body
def test_valid_redis_before_loaded_on_post(self, test_client, payload):
test_client.post('/data', data=payload['serialized'])
data = payload()
test_client.post('/data', data=data['serialized'])
keys = redis.keys('data:*')
assert len(keys) == payload['size']
assert len(keys) == data['size']
for key in keys:
value = redis.get(key)
data_state = json.loads(value)
@@ -162,9 +166,10 @@ class TestData:
assert 'meta' in data_state
def test_valid_redis_after_loaded_on_post(self, test_client, payload):
test_client.post('/data?wait=1', data=payload['serialized'])
data = payload()
test_client.post('/data?wait=1', data=data['serialized'])
keys = redis.keys('data:*')
assert len(keys) == payload['size']
assert len(keys) == data['size']
for key in keys:
value = redis.get(key)
data_state = json.loads(value)
@@ -176,7 +181,8 @@ class TestData:
def test_valid_filesystem_before_loaded_on_post(
self, test_client, payload):
data_dir = os.path.join(app.config['FRACTALIS_TMP_DIR'], 'data')
test_client.post('/data', data=payload['serialized'])
data = payload()
test_client.post('/data', data=data['serialized'])
if os.path.exists(data_dir):
assert len(os.listdir(data_dir)) == 0
keys = redis.keys('data:*')
@@ -188,8 +194,9 @@ class TestData:
def test_valid_filesystem_after_loaded_on_post(
self, test_client, payload):
data_dir = os.path.join(app.config['FRACTALIS_TMP_DIR'], 'data')
test_client.post('/data?wait=1', data=payload['serialized'])
assert len(os.listdir(data_dir)) == payload['size']
data = payload()
test_client.post('/data?wait=1', data=data['serialized'])
assert len(os.listdir(data_dir)) == data['size']
for f in os.listdir(data_dir):
assert UUID(f)
keys = redis.keys('data:*')
@@ -199,13 +206,15 @@ class TestData:
assert os.path.exists(data_state['file_path'])
def test_valid_session_on_post(self, test_client, payload):
test_client.post('/data', data=payload['serialized'])
data = payload()
test_client.post('/data', data=data['serialized'])
with test_client.session_transaction() as sess:
assert len(sess['data_tasks']) == payload['size']
assert len(sess['data_tasks']) == data['size']
def test_session_matched_redis_in_post_big_payload(
self, test_client, payload):
test_client.post('/data', data=payload['serialized'])
data = payload()
test_client.post('/data', data=data['serialized'])
with test_client.session_transaction() as sess:
for task_id in sess['data_tasks']:
assert redis.exists('data:{}'.format(task_id))
@@ -213,14 +222,18 @@ class TestData:
def test_many_post_and_valid_state(self, test_client, payload):
requests = 5
data_dir = os.path.join(app.config['FRACTALIS_TMP_DIR'], 'data')
size = 0
for i in range(requests):
rv = test_client.post('/data?wait=1', data=payload['serialized'])
data = payload()
size += data['size']
rv = test_client.post('/data?wait=1', data=data['serialized'])
assert rv.status_code == 201
assert len(os.listdir(data_dir)) == requests * payload['size']
assert len(redis.keys('data:*')) == requests * payload['size']
assert len(os.listdir(data_dir)) == size
assert len(redis.keys('data:*')) == size
def test_valid_response_before_loaded_on_get(self, test_client, payload):
test_client.post('/data', data=payload['serialized'])
data = payload()
test_client.post('/data', data=data['serialized'])
rv = test_client.get('/data')
assert rv.status_code == 200
body = flask.json.loads(rv.get_data())
@@ -233,7 +246,8 @@ class TestData:
assert 'task_id' in data_state
def test_valid_response_after_loaded_on_get(self, test_client, payload):
test_client.post('/data', data=payload['serialized'])
data = payload()
test_client.post('/data', data=data['serialized'])
rv = test_client.get('/data?wait=1')
assert rv.status_code == 200
body = flask.json.loads(rv.get_data())
@@ -244,6 +258,23 @@ class TestData:
assert data_state['data_type'] == 'mock'
assert 'task_id' in data_state
def test_discard_expired_states(self, test_client):
data_state = {
'a': 'b',
'file_path': '',
'meta': ''
}
redis.set(name='data:456', value=json.dumps(data_state))
with test_client.session_transaction() as sess:
sess['data_tasks'] = ['123', '456']
rv = test_client.get('/data?wait=1')
body = flask.json.loads(rv.get_data())
assert rv.status_code == 200, body
assert len(body['data_states']) == 1
assert body['data_states'][0]['a'] == 'b'
with test_client.session_transaction() as sess:
assert sess['data_tasks'] == ['456']
def test_valid_response_if_failing_on_get(self, test_client, faiload):
test_client.post('/data', data=faiload['serialized'])
rv = test_client.get('/data?wait=1')
@@ -258,7 +289,8 @@ class TestData:
def test_valid_state_for_finished_etl_on_delete(
self, test_client, payload):
test_client.post('/data?wait=1', data=payload['serialized'])
data = payload()
test_client.post('/data?wait=1', data=data['serialized'])
for key in redis.keys('data:*'):
value = redis.get(key)
data_state = json.loads(value)
@@ -270,7 +302,8 @@ class TestData:
assert data_state['task_id'] not in sess['data_tasks']
def test_valid_state_for_running_etl_on_delete(self, test_client, payload):
test_client.post('/data', data=payload['serialized'])
data = payload()
test_client.post('/data', data=data['serialized'])
for key in redis.keys('data:*'):
value = redis.get(key)
data_state = json.loads(value)
@@ -294,7 +327,8 @@ class TestData:
assert data_state['task_id'] not in sess['data_tasks']
def test_403_if_no_auth_on_delete(self, test_client, payload):
test_client.post('/data?wait=1', data=payload['serialized'])
data = payload()
test_client.post('/data?wait=1', data=data['serialized'])
with test_client.session_transaction() as sess:
sess['data_tasks'] = []
for key in redis.keys('data:*'):
@@ -312,7 +346,8 @@ class TestData:
def test_valid_state_for_finished_etl_on_delete_all(
self, test_client, payload):
data_dir = os.path.join(app.config['FRACTALIS_TMP_DIR'], 'data')
test_client.post('/data?wait=1', data=payload['serialized'])
data = payload()
test_client.post('/data?wait=1', data=data['serialized'])
test_client.delete('/data?wait=1')
assert not redis.keys('data:*')
assert len(os.listdir(data_dir)) == 0
@@ -321,7 +356,8 @@ class TestData:
def test_encryption_works(self, test_client, payload):
app.config['FRACTALIS_ENCRYPT_CACHE'] = True
test_client.post('/data?wait=1', data=payload['serialized'])
data = payload()
test_client.post('/data?wait=1', data=data['serialized'])
keys = redis.keys('data:*')
for key in keys:
value = redis.get(key)
@@ -332,7 +368,8 @@ class TestData:
app.config['FRACTALIS_ENCRYPT_CACHE'] = False
def test_valid_response_before_loaded_on_meta(self, test_client, payload):
test_client.post('/data', data=payload['serialized'])
data = payload()
test_client.post('/data', data=data['serialized'])
for key in redis.keys('data:*'):
value = redis.get(key)
data_state = json.loads(value)
@@ -342,7 +379,8 @@ class TestData:
assert 'features' not in body['meta']
def test_valid_response_after_loaded_on_meta(self, test_client, payload):
test_client.post('/data?wait=1', data=payload['serialized'])
data = payload()
test_client.post('/data?wait=1', data=data['serialized'])
for key in redis.keys('data:*'):
value = redis.get(key)
data_state = json.loads(value)
@@ -353,7 +391,8 @@ class TestData:
assert 'features' in body['meta']
def test_403_if_no_auth_on_get_meta(self, test_client, payload):
test_client.post('/data?wait=1', data=payload['serialized'])
data = payload()
test_client.post('/data?wait=1', data=data['serialized'])
with test_client.session_transaction() as sess: