diff --git a/metrik/tasks/base.py b/metrik/tasks/base.py
index dcd4bb2..fe3705e 100644
--- a/metrik/tasks/base.py
+++ b/metrik/tasks/base.py
@@ -1,12 +1,16 @@
 from __future__ import print_function
 
 import logging
+import datetime
+from time import sleep
 
 from luigi import Task
 from luigi.parameter import DateMinuteParameter, BoolParameter
+from pymongo import MongoClient
 
 from metrik.targets.mongo import MongoTarget
 from metrik.targets.noop import NoOpTarget
+from metrik.conf import MONGO_HOST, MONGO_PORT, MONGO_DATABASE
 
 
 class MongoCreateTask(Task):
@@ -61,3 +65,78 @@ class MongoNoBackCreateTask(MongoCreateTask):
         # wish to persist for the future.
         if self.live:
             return super(MongoNoBackCreateTask, self).run()
+
+
+class MongoRateLimit(object):
+    """Rate limiter for a named service, backed by a MongoDB collection.
+
+    Each acquired lock is stored as a document holding the service name and
+    its creation time. A new lock may be taken only while fewer than `limit`
+    locks for the service were created within the trailing `interval`.
+    """
+    rate_limit_collection = 'rate_limit'
+
+    def __init__(self, service, limit, interval, max_tries=5, backoff=.5):
+        """
+        :param service: Name of the service being rate limited
+        :param limit: Maximum number of locks allowed within `interval`
+        :param interval: Length of the trailing window in which locks count
+        :type interval: datetime.timedelta
+        :param max_tries: Maximum number of attempts made by `acquire_lock`
+        :param backoff: Multiple of `interval` to sleep between attempts
+        """
+        self.service = service
+        self.limit = limit
+        self.interval = interval
+        self.max_tries = max_tries
+        self.backoff = backoff
+        self.db = MongoClient(host=MONGO_HOST, port=MONGO_PORT)[MONGO_DATABASE]
+
+    def get_present(self):
+        """Return the current local time as a naive datetime."""
+        return datetime.datetime.now()
+
+    def query_locks(self, present):
+        """Count this service's locks created after `present - interval`."""
+        return self.db[self.rate_limit_collection].find(
+            {'_created_at': {'$gt': present - self.interval},
+             'service': self.service}).count()
+
+    def save_lock(self, present):
+        """Record a lock for this service created at `present`."""
+        self.db[self.rate_limit_collection].save({
+            '_created_at': present, 'service': self.service
+        })
+
+    def sleep_until(self, present):
+        """Return the number of seconds to wait before the next attempt.
+
+        `present` is kept for backwards compatibility only; the wait is
+        always `interval * backoff`.
+        """
+        # Scale the float number of seconds rather than the timedelta:
+        # Python 2 timedeltas cannot be multiplied by a float such as the
+        # default backoff of .5.
+        return self.interval.total_seconds() * self.backoff
+
+    def acquire_lock(self):
+        """Try to acquire a lock, retrying up to `max_tries` times.
+
+        The limit check and the save are separate database operations, so
+        concurrent callers can both pass the check before either saves.
+
+        :return: True if a lock was saved, False if every attempt found the
+            service at its limit
+        :rtype: bool
+        """
+        for attempt in range(1, self.max_tries + 1):
+            if self.query_locks(self.get_present()) < self.limit:
+                self.save_lock(self.get_present())
+                return True
+            # Don't sleep after the final attempt; just report failure
+            if attempt < self.max_tries:
+                sleep(self.sleep_until(self.get_present()))
+
+        return False
+
+
diff --git a/test/mongo_test.py b/test/mongo_test.py
index 7419463..ad0f59a 100644
--- a/test/mongo_test.py
+++ b/test/mongo_test.py
@@ -6,10 +6,14 @@ from metrik.targets.mongo import MongoTarget
 
 
 class MongoTest(TestCase):
+    def setUp(self):
+        super(MongoTest, self).setUp()
+        self.client = MongoClient(MONGO_HOST, MONGO_PORT)
+        self.db = self.client[MONGO_DATABASE]
+
     def tearDown(self):
         super(MongoTest, self).tearDown()
-        client = MongoClient(MONGO_HOST, MONGO_PORT)
-        client.drop_database(MONGO_DATABASE)
+        self.client.drop_database(MONGO_DATABASE)
 
 
 class MongoTestTest(MongoTest):
diff --git a/test/tasks/test_base.py b/test/tasks/test_base.py
index 77833a6..9e5b17d 100644
--- a/test/tasks/test_base.py
+++ b/test/tasks/test_base.py
@@ -1,11 +1,94 @@
 from unittest import TestCase
-from datetime import datetime
+from datetime import datetime, timedelta
 
-from metrik.tasks.base import MongoNoBackCreateTask
+from metrik.tasks.base import MongoNoBackCreateTask, MongoRateLimit
+from test.mongo_test import MongoTest
 
 
 class BaseTaskTest(TestCase):
-
     def test_mongo_no_back_live_false(self):
+        # Test that default for `live` parameter is False
         task = MongoNoBackCreateTask(current_datetime=datetime.now())
-        assert not task.live
\ No newline at end of file
+        assert not task.live
+
+
+class RateLimitTest(MongoTest):
+    def test_save_creates_record(self):
+        service = 'testing_ratelimit'
+        assert self.db[MongoRateLimit.rate_limit_collection].count() == 0
+
+        present = datetime.now()
+        onesec_back = present - timedelta(seconds=1)
+        ratelimit = MongoRateLimit(
+            service, 1, timedelta(seconds=1)
+        )
+        assert ratelimit.query_locks(onesec_back) == 0
+
+        ratelimit.save_lock(present)
+        assert self.db[MongoRateLimit.rate_limit_collection].count() == 1
+        assert ratelimit.query_locks(onesec_back) == 1
+
+    def test_save_creates_correct_service(self):
+        service_1 = 'testing_ratelimit_1'
+        service_2 = 'testing_ratelimit_2'
+
+        ratelimit1 = MongoRateLimit(
+            service_1, 1, timedelta(seconds=1)
+        )
+        ratelimit2 = MongoRateLimit(
+            service_2, 1, timedelta(seconds=1)
+        )
+
+        present = datetime.now()
+        assert self.db[MongoRateLimit.rate_limit_collection].count() == 0
+        assert ratelimit1.query_locks(present) == 0
+        assert ratelimit2.query_locks(present) == 0
+
+        ratelimit1.save_lock(present)
+        assert self.db[MongoRateLimit.rate_limit_collection].count() == 1
+        assert ratelimit1.query_locks(present) == 1
+        assert ratelimit2.query_locks(present) == 0
+
+    def test_acquire_lock_fails(self):
+        service = 'testing_ratelimit'
+
+        # The scenario is as follows:
+        # We try to acquire a lock with 1 try, backoff is 10.
+        # We are checking for locks up to 1 second ago, and there
+        # is a lock in the database that was saved just now. Thus,
+        # we should fail immediately since we did not acquire the
+        # lock and are only allowed one try.
+        # Ultimately, we are testing that the 'fail immediately'
+        # switch gets triggered correctly
+        ratelimit = MongoRateLimit(
+            service, 1, timedelta(seconds=1), max_tries=1, backoff=10
+        )
+
+        start = datetime.now()
+        ratelimit.save_lock(start)
+        did_acquire = ratelimit.acquire_lock()
+        end = datetime.now()
+        assert not did_acquire
+        assert (end - start).total_seconds() < 1
+
+    def test_acquire_lock_succeeds(self):
+        service = 'testing_ratelimit'
+
+        # The scenario is as follows:
+        # We try to acquire a lock with two tries, backoff is 1.
+        # We put a single lock in initially (a half second in the past),
+        # thus when we try to acquire on the first try, we should fail.
+        # However, the backoff should kick in, and we acquire successfully
+        # on the second try
+        ratelimit = MongoRateLimit(
+            service, 1, timedelta(seconds=1), max_tries=2, backoff=1
+        )
+
+        start = datetime.now()
+        ratelimit.save_lock(start - timedelta(seconds=.5))
+        did_acquire = ratelimit.acquire_lock()
+        end = datetime.now()
+        # Check that we acquired the lock
+        assert did_acquire
+        # Check that we only used one backoff period
+        assert (end - start).total_seconds() < 2