RxNettyHttpClient.java 12.9 KB
Newer Older
1 2
package com.netflix.client.netty.http;

3
import io.netty.bootstrap.Bootstrap;
4 5 6 7 8 9 10 11 12 13 14
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.QueryStringEncoder;

15 16 17 18
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
19
import java.util.Map.Entry;
20 21
import java.util.concurrent.TimeUnit;

22 23 24 25 26 27 28 29 30 31
import rx.Observable;
import rx.netty.protocol.http.HttpProtocolHandler;
import rx.netty.protocol.http.HttpProtocolHandlerAdapter;
import rx.netty.protocol.http.Message;
import rx.netty.protocol.http.ObservableHttpClient;
import rx.netty.protocol.http.ObservableHttpResponse;
import rx.netty.protocol.http.SelfRemovingResponseTimeoutHandler;
import rx.netty.protocol.http.ValidatedFullHttpRequest;
import rx.util.functions.Func1;

32 33 34 35
import com.netflix.client.ClientException;
import com.netflix.client.config.CommonClientConfigKey;
import com.netflix.client.config.DefaultClientConfigImpl;
import com.netflix.client.config.IClientConfig;
36
import com.netflix.client.config.IClientConfigKey;
37 38
import com.netflix.client.http.HttpRequest;
import com.netflix.client.http.HttpResponse;
39
import com.netflix.client.http.UnexpectedHttpResponseException;
40 41
import com.netflix.client.http.HttpRequest.Builder;
import com.netflix.client.http.HttpRequest.Verb;
42
import com.netflix.serialization.Deserializer;
43
import com.netflix.serialization.HttpSerializationContext;
44 45
import com.netflix.serialization.JacksonSerializationFactory;
import com.netflix.serialization.SerializationFactory;
46
import com.netflix.serialization.SerializationUtils;
47 48 49 50 51 52
import com.netflix.serialization.Serializer;
import com.netflix.serialization.TypeDef;

public class RxNettyHttpClient {

    private ObservableHttpClient observableClient;
53
    private SerializationFactory<HttpSerializationContext> serializationFactory;
54
    private int connectTimeout;
55
    private int readTimeout;
56 57 58
    private IClientConfig config;
    
    /**
     * Creates a client that uses the default client configuration values, a
     * {@link JacksonSerializationFactory}, and a new {@link Bootstrap} backed by a
     * newly created {@link NioEventLoopGroup}.
     */
    public RxNettyHttpClient() {
        this(
                DefaultClientConfigImpl.getClientConfigWithDefaultValues(),
                new JacksonSerializationFactory(),
                new Bootstrap().group(new NioEventLoopGroup()));
    }
    
    /**
     * Creates a client for the given configuration, using a
     * {@link JacksonSerializationFactory} and a new {@link Bootstrap} backed by a
     * newly created {@link NioEventLoopGroup}.
     *
     * @param config client configuration supplying the connect and read timeouts
     */
    public RxNettyHttpClient(IClientConfig config) {
        this(
                config,
                new JacksonSerializationFactory(),
                new Bootstrap().group(new NioEventLoopGroup()));
    }
    
    
68 69
    /**
     * Creates a client from explicitly supplied collaborators.
     *
     * @param config               client configuration; the connect and read timeouts are read
     *                             from it, falling back to the {@link DefaultClientConfigImpl} defaults
     * @param serializationFactory factory used to serialize request entities and
     *                             deserialize response entities
     * @param bootStrap            Netty bootstrap whose event loop group is handed to the
     *                             underlying {@link ObservableHttpClient}
     */
    public RxNettyHttpClient(
            IClientConfig config,
            SerializationFactory<HttpSerializationContext> serializationFactory,
            Bootstrap bootStrap) {
        this.serializationFactory = serializationFactory;
        this.config = config;
        this.connectTimeout = config.getPropertyAsInteger(
                CommonClientConfigKey.ConnectTimeout,
                DefaultClientConfigImpl.DEFAULT_CONNECT_TIMEOUT);
        this.readTimeout = config.getPropertyAsInteger(
                CommonClientConfigKey.ReadTimeout,
                DefaultClientConfigImpl.DEFAULT_READ_TIMEOUT);
        // The connect timeout is fixed per client; the read timeout is applied per request.
        this.observableClient = ObservableHttpClient.newBuilder()
                .withChannelOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
                .build(bootStrap.group());
    }
    
79 80
    private class SingleEntityHandler<T> extends HttpProtocolHandlerAdapter<T> {
        private HttpRequest request;
81
        private TypeDef<T> typeDef;
82
        
83 84
        private SingleEntityHandler(HttpRequest request, SerializationFactory<HttpSerializationContext> serializationFactory, TypeDef<T> typeDef) {
            this.request = request;
85
            this.typeDef = typeDef;
86 87 88 89
        }

        @Override
        public void configure(ChannelPipeline pipeline) {
90 91
            int timeout = readTimeout;
            if (request.getOverrideConfig() != null) {
92
                Integer overrideTimeout = request.getOverrideConfig().getPropertyWithType(CommonClientConfigKey.ReadTimeout);
93 94 95 96
                if (overrideTimeout != null) {
                    timeout = overrideTimeout.intValue();
                }
            }
97
            pipeline.addAfter("http-response-decoder", "http-aggregator", new HttpObjectAggregator(Integer.MAX_VALUE));
98
            pipeline.addAfter("http-aggregator", SelfRemovingResponseTimeoutHandler.NAME, new SelfRemovingResponseTimeoutHandler(timeout, TimeUnit.MILLISECONDS));
99
            pipeline.addAfter(SelfRemovingResponseTimeoutHandler.NAME, "entity-decoder", new HttpEntityDecoder<T>(serializationFactory, request, typeDef));
100 101
        }
    }
102
        
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
    private ValidatedFullHttpRequest getHttpRequest(HttpRequest request) throws ClientException {
        ValidatedFullHttpRequest r = null;
        Object entity = request.getEntity();
        String uri = request.getUri().toString();
        if (request.getQueryParams() != null) {
            QueryStringEncoder encoder = new QueryStringEncoder(uri);
            for (Map.Entry<String, Collection<String>> entry: request.getQueryParams().entrySet()) {
                String name = entry.getKey();
                Collection<String> values = entry.getValue();
                for (String value: values) {
                    encoder.addParam(name, value);
                }
            }
            uri = encoder.toString();
        }
        if (entity != null) {
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
            ByteBuf buf = null;
            int contentLength = -1;
            if (entity instanceof ByteBuf) {
                buf = (ByteBuf) entity;
            } else {
                Serializer serializer = null;
                if (request.getOverrideConfig() != null) {
                    serializer = request.getOverrideConfig().getPropertyWithType(CommonClientConfigKey.Serializer);
                }
                if (serializer == null) {
                    HttpSerializationContext key = new HttpSerializationContext(request.getHttpHeaders(), request.getUri());
                    serializer = serializationFactory.getSerializer(key, request.getEntityType());
                }
                if (serializer == null) {
                    throw new ClientException("Unable to find serializer");
                }
                ByteArrayOutputStream bout = new ByteArrayOutputStream();
                try {
                    serializer.serialize(bout, entity, request.getEntityType());
                } catch (IOException e) {
                    throw new ClientException("Error serializing entity in request", e);
                }
                byte[] content = bout.toByteArray();
                buf = Unpooled.wrappedBuffer(content);
                contentLength = content.length;
144 145
            }
            r = new ValidatedFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.valueOf(request.getVerb().name()), uri, buf);
146 147 148
            if (contentLength >= 0) {
                r.headers().set(HttpHeaders.Names.CONTENT_LENGTH, contentLength);
            }
149 150 151
        } else {
            r = new ValidatedFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.valueOf(request.getVerb().name()), uri);
        }
152 153 154
        if (request.getHttpHeaders() != null) {
            for (Entry<String, String> header: request.getHttpHeaders().getAllHeaders()) {
                r.headers().set(header.getKey(), header.getValue());
155 156 157 158 159 160 161 162 163 164 165 166
            }
        }
        if (request.getUri().getHost() != null) {
            r.headers().set(HttpHeaders.Names.HOST, request.getUri().getHost());
        }
        return r;
    }

    /**
     * Returns the client configuration this instance was created with.
     *
     * @return the client configuration
     */
    public IClientConfig getConfig() {
        return this.config;
    }

167 168 169 170
    /**
     * Returns the factory this client uses to look up serializers and deserializers.
     *
     * @return the serialization factory
     */
    public SerializationFactory<HttpSerializationContext> getSerializationFactory() {
        return this.serializationFactory;
    }
    
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
    /**
     * Executes the request as a server-sent-event stream and converts the data of each
     * event into an entity of the requested type.
     *
     * <p>A response whose status code is not 200 is reported as an error carrying an
     * {@link UnexpectedHttpResponseException}. An {@link IOException} raised while
     * deserializing an event is rethrown wrapped in a {@link RuntimeException}.
     *
     * @param request request to execute
     * @param typeDef type each event's data is deserialized into
     * @param <T>     entity type carried by the events
     * @return observable of typed server-sent events
     */
    public <T> Observable<ServerSentEvent<T>> createServerSentEventEntityObservable(final HttpRequest request, final TypeDef<T> typeDef) {
        Observable<ObservableHttpResponse<Message>> rawResponses = createServerSentEventObservable(request);
        return rawResponses.flatMap(new Func1<ObservableHttpResponse<Message>, Observable<ServerSentEvent<T>>>() {
            @Override
            public Observable<ServerSentEvent<T>> call(ObservableHttpResponse<Message> sseResponse) {
                io.netty.handler.codec.http.HttpResponse nettyResponse = sseResponse.response();
                if (nettyResponse.getStatus().code() != 200) {
                    UnexpectedHttpResponseException failure = new UnexpectedHttpResponseException(
                            new NettyHttpResponse(nettyResponse, null, serializationFactory, request));
                    return Observable.<ServerSentEvent<T>>error(failure);
                }

                // Resolve the deserializer once per response, not once per event.
                final Deserializer<T> deserializer = SerializationUtils.getDeserializer(
                        request, new NettyHttpHeaders(nettyResponse), typeDef, serializationFactory);

                Func1<Message, ServerSentEvent<T>> toTypedEvent = new Func1<Message, ServerSentEvent<T>>() {
                    @Override
                    public ServerSentEvent<T> call(Message message) {
                        try {
                            T data = SerializationUtils.deserializeFromString(deserializer, message.getEventData(), typeDef);
                            return new ServerSentEvent<T>(message.getEventId(), message.getEventName(), data);
                        } catch (IOException e) {
                            throw new RuntimeException(e);
                        }
                    }
                };
                return sseResponse.content().map(toTypedEvent);
            }
        });
    }
197

198 199 200 201 202 203 204 205 206 207 208
    /**
     * Executes the request using {@code HttpProtocolHandler.SSE_HANDLER} and returns the
     * resulting responses, whose content is a stream of {@link Message} objects.
     *
     * @param request request to execute
     * @return observable of responses carrying server-sent-event messages
     */
    public Observable<ObservableHttpResponse<Message>> createServerSentEventObservable(final HttpRequest request) {
        Observable<ObservableHttpResponse<Message>> sseResponses =
                createObservableHttpResponse(request, HttpProtocolHandler.SSE_HANDLER);
        return sseResponses;
    }
    
    /**
     * Executes the request and emits the response as an {@link HttpResponse}; this is
     * {@link #createEntityObservable} with {@code HttpResponse} as the entity type.
     *
     * @param request request to execute
     * @return observable emitting the response
     */
    public Observable<HttpResponse> createFullHttpResponseObservable(final HttpRequest request) {
        TypeDef<HttpResponse> responseType = TypeDef.fromClass(HttpResponse.class);
        return createEntityObservable(request, responseType);
    }
    
    
    public <T> Observable<T> createEntityObservable(final HttpRequest request, TypeDef<T> typeDef) {
        Observable<ObservableHttpResponse<T>> observableHttpResponse = createObservableHttpResponse(request, new SingleEntityHandler<T>(request, serializationFactory, typeDef));
209 210 211 212 213 214 215
        return observableHttpResponse.flatMap(new Func1<ObservableHttpResponse<T>, Observable<T>>() {
            @Override
            public Observable<T> call(ObservableHttpResponse<T> t1) {
                return t1.content();
            }
        });
    }    
216
    
217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
    /**
     * Looks up a property, preferring the override configuration over this client's own.
     *
     * @param key            key of the property to read
     * @param overrideConfig per-request configuration, may be {@code null}
     * @param <T>            property type
     * @return the value from {@code overrideConfig} when it is non-null there, otherwise
     *         whatever {@link #getConfig()} holds for the key (possibly {@code null})
     */
    private <T> T getProperty(IClientConfigKey<T> key, IClientConfig overrideConfig) {
        if (overrideConfig != null) {
            T overridden = overrideConfig.getPropertyWithType(key);
            if (overridden != null) {
                return overridden;
            }
        }
        return getConfig().getPropertyWithType(key);
    }
        
    /** Unchecked exception type for redirect-related failures; it carries only a message. */
    private static class RedirectException extends RuntimeException {

        /**
         * @param message description of the failure
         */
        public RedirectException(final String message) {
            super(message);
        }
    }
233
    
234
    public <T> Observable<ObservableHttpResponse<T>> createObservableHttpResponse(final HttpRequest request, final HttpProtocolHandler<T> protocolHandler) {
235 236 237 238 239 240
        ValidatedFullHttpRequest r = null;
        try {
            r = getHttpRequest(request);
        } catch (final Exception e) {
            return Observable.error(e);
        }
241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
        final Observable<ObservableHttpResponse<T>> observable = observableClient.execute(r, protocolHandler);
        Boolean followRedirect = getProperty(CommonClientConfigKey.FollowRedirects, request.getOverrideConfig());
        if (followRedirect != null && followRedirect.booleanValue()) {
            return observable.flatMap(new Func1<ObservableHttpResponse<T>, Observable<ObservableHttpResponse<T>>>() {
                @Override
                public Observable<ObservableHttpResponse<T>> call(
                        ObservableHttpResponse<T> t1) {
                    int statusCode = t1.response().getStatus().code();
                    switch (statusCode) {
                    case 301:
                    case 302:
                    case 303:
                    case 307:
                    case 308:
                        String location = t1.response().headers().get("Location");
                        if (location == null) {
                            return Observable.error(new Exception("Location header is not set in the redirect response"));
                        } 
                        
                        Builder builder = HttpRequest.newBuilder(request).uri(location);
                        if (statusCode == 303) {
                            // according to the spec, this must be done with GET
                            builder.verb(Verb.GET);
                        }
                        Observable<ObservableHttpResponse<T>> newObservable = createObservableHttpResponse(builder.build(), protocolHandler);
                        return newObservable;
                    default: break;
                    }
                    return Observable.from(t1);
                }
            });
        } else {
            return observable;
        }
275
    }    
276
}