diff --git a/octodns/provider/ns1.py b/octodns/provider/ns1.py index 8acbd53..da2d64a 100644 --- a/octodns/provider/ns1.py +++ b/octodns/provider/ns1.py @@ -22,51 +22,44 @@ from .base import BaseProvider class Ns1Client(object): log = getLogger('NS1Client') - def __init__(self, api_key, retry_delay=1): - self.retry_delay = retry_delay + def __init__(self, api_key, retry_count=4): + self.retry_count = retry_count client = NS1(apiKey=api_key) self._records = client.records() self._zones = client.zones() + def _try(self, method, *args, **kwargs): + tries = self.retry_count + while tries: + try: + return method(*args, **kwargs) + except RateLimitException as e: + period = float(e.period) + self.log.warn('rate limit encountered, pausing ' + 'for %ds and trying again, %d remaining', + period, tries) + sleep(period) + tries -= 1 + raise + def zones_retrieve(self, name): - return self._zones.retrieve(name) + return self._try(self._zones.retrieve, name) def zones_create(self, name): - return self._zones.create(name) + return self._try(self._zones.create, name) def records_retrieve(self, zone, domain, _type): - return self._records.retrieve(zone, domain, _type) + return self._try(self._records.retrieve, zone, domain, _type) def records_create(self, zone, domain, _type, **params): - try: - return self._records.create(zone, domain, _type, **params) - except RateLimitException as e: - period = float(e.period) - self.log.warn('_apply_Create: rate limit encountered, pausing ' - 'for %ds and trying again', period) - sleep(period) - return self._records.create(zone, domain, _type, **params) + return self._try(self._records.create, zone, domain, _type, **params) def records_update(self, zone, domain, _type, **params): - try: - return self._records.update(zone, domain, _type, **params) - except RateLimitException as e: - period = float(e.period) - self.log.warn('_apply_Update: rate limit encountered, pausing ' - 'for %ds and trying again', period) - sleep(period) - return self._records.update(zone, domain, _type, **params) + return self._try(self._records.update, zone, domain, _type, **params) def records_delete(self, zone, domain, _type): - try: - return self._records.delete(zone, domain, _type) - except RateLimitException as e: - period = float(e.period) - self.log.warn('_apply_Delete: rate limit encountered, pausing ' - 'for %ds and trying again', period) - sleep(period) - return self._records.delete(zone, domain, _type) + return self._try(self._records.delete, zone, domain, _type) class Ns1Provider(BaseProvider): @@ -84,12 +77,12 @@ class Ns1Provider(BaseProvider): ZONE_NOT_FOUND_MESSAGE = 'server error: zone not found' - def __init__(self, id, api_key, retry_delay=1, *args, **kwargs): + def __init__(self, id, api_key, retry_count=4, *args, **kwargs): self.log = getLogger('Ns1Provider[{}]'.format(id)) - self.log.debug('__init__: id=%s, api_key=***, retry_delay=%d', id, - retry_delay) + self.log.debug('__init__: id=%s, api_key=***, retry_count=%d', id, + retry_count) super(Ns1Provider, self).__init__(id, *args, **kwargs) - self._client = Ns1Client(api_key, retry_delay) + self._client = Ns1Client(api_key, retry_count) def _data_for_A(self, _type, record): # record meta (which would include geo information is only diff --git a/tests/test_octodns_provider_ns1.py b/tests/test_octodns_provider_ns1.py index 0f23222..0743943 100644 --- a/tests/test_octodns_provider_ns1.py +++ b/tests/test_octodns_provider_ns1.py @@ -9,10 +9,11 @@ from collections import defaultdict from mock import call, patch from ns1.rest.errors import AuthException, RateLimitException, \ ResourceException +from six import text_type from unittest import TestCase from octodns.record import Delete, Record, Update -from octodns.provider.ns1 import Ns1Provider +from octodns.provider.ns1 import Ns1Client, Ns1Provider from octodns.zone import Zone @@ -497,3 +498,46 @@ class TestNs1Provider(TestCase): } self.assertEqual(b_expected, provider._data_for_CNAME(b_record['type'], b_record)) + + +class TestNs1Client(TestCase): + + @patch('ns1.rest.zones.Zones.retrieve') + def test_retry_behavior(self, zone_retrieve_mock): + client = Ns1Client('dummy-key') + + # No retry required, just calls and is returned + zone_retrieve_mock.reset_mock() + zone_retrieve_mock.side_effect = ['foo'] + self.assertEquals('foo', client.zones_retrieve('unit.tests')) + zone_retrieve_mock.assert_has_calls([call('unit.tests')]) + + # One retry required + zone_retrieve_mock.reset_mock() + zone_retrieve_mock.side_effect = [ + RateLimitException('boo', period=0), + 'foo' + ] + self.assertEquals('foo', client.zones_retrieve('unit.tests')) + zone_retrieve_mock.assert_has_calls([call('unit.tests')]) + + # Two retries required + zone_retrieve_mock.reset_mock() + zone_retrieve_mock.side_effect = [ + RateLimitException('boo', period=0), + 'foo' + ] + self.assertEquals('foo', client.zones_retrieve('unit.tests')) + zone_retrieve_mock.assert_has_calls([call('unit.tests')]) + + # Exhaust our retries + zone_retrieve_mock.reset_mock() + zone_retrieve_mock.side_effect = [ + RateLimitException('first', period=0), + RateLimitException('boo', period=0), + RateLimitException('boo', period=0), + RateLimitException('last', period=0), + ] + with self.assertRaises(RateLimitException) as ctx: + client.zones_retrieve('unit.tests') + self.assertEquals('last', text_type(ctx.exception))