Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/java-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ jobs:
--file duo-client/pom.xml
- name: Test with Maven
run: >
mvn test
--batch-mode
mvn test
--batch-mode
-file duo-client/pom.xml
- name: Lint with checkstyle
run: mvn checkstyle:check
38 changes: 37 additions & 1 deletion duo-client/src/main/java/com/duosecurity/client/Http.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class Http {
private Headers.Builder headers;
private SortedMap<String, Object> params = new TreeMap<String, Object>();
protected int sigVersion = 5;
private long maxBackoffMs = MAX_BACKOFF_MS;
private Random random = new Random();
private OkHttpClient httpClient;
private SortedMap<String, String> additionalDuoHeaders = new TreeMap<String, String>();
Expand Down Expand Up @@ -314,7 +315,7 @@ private Response executeRequest(Request request) throws Exception {
long backoffMs = INITIAL_BACKOFF_MS;
while (true) {
Response response = httpClient.newCall(request).execute();
if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > MAX_BACKOFF_MS) {
if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > maxBackoffMs) {
return response;
}

Expand All @@ -327,6 +328,13 @@ protected void sleep(long ms) throws Exception {
Thread.sleep(ms);
}

protected void setMaxBackoffMs(long maxBackoffMs) {
if (maxBackoffMs < 0) {
throw new IllegalArgumentException("maxBackoffMs must be >= 0");
}
this.maxBackoffMs = maxBackoffMs;
}

public void signRequest(String ikey, String skey)
throws UnsupportedEncodingException {
signRequest(ikey, skey, sigVersion);
Expand Down Expand Up @@ -529,6 +537,7 @@ protected abstract static class ClientBuilder<T extends Http> {
private final String uri;

private int timeout = DEFAULT_TIMEOUT_SECS;
private long maxBackoffMs = MAX_BACKOFF_MS;
private String[] caCerts = null;
private SortedMap<String, String> additionalDuoHeaders = new TreeMap<String, String>();
private Map<String, String> headers = new HashMap<String, String>();
Expand Down Expand Up @@ -558,6 +567,32 @@ public ClientBuilder<T> useTimeout(int timeout) {
return this;
}

/**
* Set the maximum base backoff time in milliseconds for rate limit (429) retries.
* When a request receives a 429 response, the client retries with exponential
* backoff until the base backoff exceeds this threshold. Note that actual sleep
* time includes up to 1000ms of random jitter on top of the base backoff.
* Setting to 0 disables retries (as does any value below the initial
* backoff of 1000ms). Default is 32000ms (32 seconds).
*
* <p>Note: When using method chaining from outside this package (e.g. with
* {@code AuthBuilder} or {@code AdminBuilder}), assign the builder to a variable
* and call methods separately, then call {@code build()}. This is a known
* limitation of all {@code ClientBuilder} methods.
*
* @param maxBackoffMs the maximum base backoff in milliseconds (must be >= 0)
* @return the Builder
* @throws IllegalArgumentException if maxBackoffMs is negative
*/
public ClientBuilder<T> useMaxBackoffMs(long maxBackoffMs) {
if (maxBackoffMs < 0) {
throw new IllegalArgumentException("maxBackoffMs must be >= 0");
}
this.maxBackoffMs = maxBackoffMs;

return this;
}

/**
* Provide custom CA certificates for certificate pinning.
*
Expand Down Expand Up @@ -604,6 +639,7 @@ public ClientBuilder<T> addHeader(String name, String value) {
*/
public T build() {
T duoClient = createClient(method, host, uri, timeout);
duoClient.setMaxBackoffMs(maxBackoffMs);
if (caCerts != null) {
duoClient.useCustomCertificates(caCerts);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ public class HttpRateLimitRetryTest {

private final int RANDOM_INT = 234;

@Before
public void before() throws Exception {
http = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build();
http = Mockito.spy(http);
private void setupHttp(Http client) throws Exception {
http = Mockito.spy(client);

Field httpClientField = Http.class.getDeclaredField("httpClient");
httpClientField.setAccessible(true);
Expand All @@ -39,6 +37,12 @@ public void before() throws Exception {
Mockito.doNothing().when(http).sleep(Mockito.any(Long.class));
}

@Before
public void before() throws Exception {
Http client = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build();
setupHttp(client);
}

@Test
public void testSingleRateLimitRetry() throws Exception {
final List<Response> responses = new ArrayList<Response>();
Expand Down Expand Up @@ -128,4 +132,98 @@ public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
assertEquals(16000L + RANDOM_INT, (long) sleepTimes.get(4));
assertEquals(32000L + RANDOM_INT, (long) sleepTimes.get(5));
}

@Test
public void testMaxBackoffZeroDisablesRetry() throws Exception {
Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(0)
.build();
setupHttp(customHttp);

final List<Response> responses = new ArrayList<Response>();

Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer<Call>() {
@Override
public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
Call call = Mockito.mock(Call.class);

Response resp = new Response.Builder()
.protocol(Protocol.HTTP_2)
.code(429)
.request((Request) invocationOnMock.getArguments()[0])
.message("HTTP 429")
.build();
responses.add(resp);
Mockito.when(call.execute()).thenReturn(resp);

return call;
}
});

Response actualRes = http.executeHttpRequest();
assertEquals(1, responses.size());
assertEquals(429, actualRes.code());

// Verify no sleep was called
Mockito.verify(http, Mockito.never()).sleep(Mockito.any(Long.class));
}

@Test
public void testMaxBackoffCustomLimit() throws Exception {
Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(4000)
.build();
setupHttp(customHttp);

final List<Response> responses = new ArrayList<Response>();

Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer<Call>() {
@Override
public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
Call call = Mockito.mock(Call.class);

Response resp = new Response.Builder()
.protocol(Protocol.HTTP_2)
.code(429)
.request((Request) invocationOnMock.getArguments()[0])
.message("HTTP 429")
.build();
responses.add(resp);
Mockito.when(call.execute()).thenReturn(resp);

return call;
}
});

// With maxBackoff=4000, retries at 1000, 2000, 4000, then 8000 > 4000 exits
// That's 4 total requests (1 initial + 3 retries)
Response actualRes = http.executeHttpRequest();
assertEquals(4, responses.size());
assertEquals(429, actualRes.code());

ArgumentCaptor<Long> sleepCapture = ArgumentCaptor.forClass(Long.class);
Mockito.verify(http, Mockito.times(3)).sleep(sleepCapture.capture());
List<Long> sleepTimes = sleepCapture.getAllValues();
assertEquals(1000L + RANDOM_INT, (long) sleepTimes.get(0));
assertEquals(2000L + RANDOM_INT, (long) sleepTimes.get(1));
assertEquals(4000L + RANDOM_INT, (long) sleepTimes.get(2));
}

@Test
public void testDefaultMaxBackoffIsUsedWhenNotSpecified() throws Exception {
Http defaultHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build();

Field maxBackoffField = Http.class.getDeclaredField("maxBackoffMs");
maxBackoffField.setAccessible(true);
long actualMaxBackoff = (long) maxBackoffField.get(defaultHttp);

assertEquals(Http.MAX_BACKOFF_MS, actualMaxBackoff);
}

@Test(expected = IllegalArgumentException.class)
public void testMaxBackoffNegativeThrows() {
new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(-1)
.build();
}
}
Loading