2016-07-11 64 views
0

Netty:在 WebSocket 握手中設置 Cookie

我的 Netty 管線(pipeline)如下:

ChannelPipeline pipeline = ch.pipeline();
// HTTP codec, then an aggregator (max 65536 bytes), then the WebSocket protocol handler
// for WEBSOCKET_PATH (no subprotocols, extensions allowed). addLast returns the pipeline,
// so the calls can be chained; handlers are added in the same order as before.
pipeline.addLast(new HttpServerCodec())
        .addLast(new HttpObjectAggregator(65536))
        .addLast(new WebSocketServerProtocolHandler(WEBSOCKET_PATH, null, true));

我想在握手響應中添加 Set-Cookie HTTP 標頭。以下是 RFC 6455 的相關內容:

服務器發出的握手響應如下所示:

Connection:upgrade 
    Sec-Websocket-Accept:T1UGQ4HhT3dvLNq5Yi+i/gfASi8= 
    Upgrade:websocket 
    Set-Cookie: ccc=22; path=/; HttpOnly 

在這兩種情況下,起始行之後都跟隨一組無序的標頭欄位。這些標頭欄位的含義在本文件第 4 節中定義。也可以出現額外的標頭欄位,例如 cookie [RFC6265]。

回答

0

我沒有找到好的方法,最後是透過反射調用 Netty 的非公開方法來實現的。

Netty 版本:4.1.2.Final

首先找到 WebSocketServerProtocolHandshakeHandler 類的源代碼。這個類不是公開的,因此需要複製一份,並在副本的基礎上修改:

/**
 * Copy of Netty 4.1.2.Final's non-public {@code WebSocketServerProtocolHandshakeHandler},
 * changed so that the WebSocket handshake response carries a {@code Set-Cookie} header
 * holding a connection id.
 *
 * <p>{@code WebSocketServerProtocolHandler.setHandshaker(Channel, WebSocketServerHandshaker)} and
 * {@code WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()} are not accessible to
 * this class, so they are looked up reflectively (with {@code setAccessible(true)}) once, at class
 * initialization. They are Netty internals: if a Netty version no longer provides them, class
 * initialization fails with a descriptive {@link IllegalStateException} rather than terminating
 * the JVM.
 */
class CustomWebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {

    /** Name of the cookie that carries the connection id. */
    private static final String CONNECTION_ID_COOKIE = "cid";
    /** Client-supplied ids shorter than this are replaced with a fresh one. */
    private static final int MIN_CONNECTION_ID_LENGTH = 16;
    /** Client-supplied ids longer than this are replaced with a fresh one. */
    private static final int MAX_CONNECTION_ID_LENGTH = 50;

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadSize;
    private final boolean allowMaskMismatch;
    static final MethodHandle setHandshakerMethod = getSetHandshakerMethod();
    static final MethodHandle forbiddenHttpRequestResponderMethod = getForbiddenHttpRequestResponderMethod();

    /**
     * Looks up {@code WebSocketServerProtocolHandler.setHandshaker(Channel, WebSocketServerHandshaker)}.
     *
     * @return a method handle for it
     * @throws IllegalStateException if the method cannot be found or made accessible
     */
    static MethodHandle getSetHandshakerMethod() {
        try {
            Method method = WebSocketServerProtocolHandler.class.getDeclaredMethod(
                    "setHandshaker", Channel.class, WebSocketServerHandshaker.class);
            method.setAccessible(true);
            return MethodHandles.lookup().unreflect(method);
        } catch (ReflectiveOperationException | RuntimeException e) {
            // Netty internal API changed or is inaccessible: fail loudly, keep the cause.
            throw new IllegalStateException(
                    "Cannot access WebSocketServerProtocolHandler.setHandshaker (unsupported Netty version?)", e);
        }
    }

    /**
     * Looks up {@code WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()}.
     *
     * @return a method handle for it
     * @throws IllegalStateException if the method cannot be found or made accessible
     */
    static MethodHandle getForbiddenHttpRequestResponderMethod() {
        try {
            Method method = WebSocketServerProtocolHandler.class.getDeclaredMethod("forbiddenHttpRequestResponder");
            method.setAccessible(true);
            return MethodHandles.lookup().unreflect(method);
        } catch (ReflectiveOperationException | RuntimeException e) {
            throw new IllegalStateException(
                    "Cannot access WebSocketServerProtocolHandler.forbiddenHttpRequestResponder"
                            + " (unsupported Netty version?)", e);
        }
    }

    public CustomWebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        this.maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
    }

    /**
     * Handles one inbound message.
     *
     * <ul>
     *   <li>If its URI is not exactly {@link #websocketPath}, it is passed on unchanged and not
     *       released.</li>
     *   <li>A non-GET request to the path gets a 403 Forbidden response.</li>
     *   <li>Otherwise a WebSocket handshake is performed, with the connection-id
     *       {@code Set-Cookie} header added to the response.</li>
     * </ul>
     *
     * <p>Once a handshake is started, this handler replaces itself with Netty's 403 responder.
     * The message is cast to {@link FullHttpRequest}, so an {@code HttpObjectAggregator} is
     * expected earlier in the pipeline.
     */
    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        FullHttpRequest req = (FullHttpRequest) msg;
        if (!websocketPath.equals(req.uri())) {
            ctx.fireChannelRead(msg);
            return;
        }

        try {
            if (req.method() != GET) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }

            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                Channel channel = ctx.channel();
                // Same as Netty's handler, except extra response headers (the cookie) are passed in.
                final ChannelFuture handshakeFuture =
                        handshaker.handshake(channel, req, getResponseHeaders(req), channel.newPromise());

                handshakeFuture.addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        if (!future.isSuccess()) {
                            ctx.fireExceptionCaught(future.cause());
                        } else {
                            ctx.fireUserEventTriggered(
                                    WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                        }
                    }
                });

                installHandshaker(ctx, handshaker);
            }
        } finally {
            req.release();
        }
    }

    /**
     * Registers the handshaker on the channel through Netty's internal
     * {@code setHandshaker}, then replaces this handler with Netty's 403 responder.
     * Failures are rethrown into the pipeline instead of terminating the JVM.
     */
    private void installHandshaker(ChannelHandlerContext ctx, WebSocketServerHandshaker handshaker) {
        try {
            setHandshakerMethod.invokeExact(ctx.channel(), handshaker);

            ChannelHandler handler = (ChannelHandler) forbiddenHttpRequestResponderMethod.invokeExact();
            ctx.pipeline().replace(this, "WS403Responder", handler);
        } catch (RuntimeException | Error e) {
            throw e;
        } catch (Throwable e) {
            // invokeExact is declared to throw Throwable; wrap checked ones, keeping the cause.
            throw new IllegalStateException("Failed to install WebSocket handshaker", e);
        }
    }

    /** Writes {@code res}; closes the connection unless keep-alive was requested and status is 200. */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /** Builds the ws:// (or wss:// when an SslHandler is in the pipeline) URL from the Host header. */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        return protocol + "://" + req.headers().get(HttpHeaderNames.HOST) + path;
    }

    /**
     * Builds the extra handshake-response headers: a {@code Set-Cookie} that echoes the
     * client's connection id if it has an acceptable length, or a new random 32-hex-char id
     * otherwise.
     */
    private static HttpHeaders getResponseHeaders(FullHttpRequest req) {
        String connectionId = findCookieValue(req, CONNECTION_ID_COOKIE);
        if (connectionId == null
                || connectionId.length() < MIN_CONNECTION_ID_LENGTH
                || connectionId.length() > MAX_CONNECTION_ID_LENGTH) {
            connectionId = UUID.randomUUID().toString().replace("-", "");
        }

        DefaultCookie cookie = new DefaultCookie(CONNECTION_ID_COOKIE, connectionId);
        cookie.setPath("/");
        cookie.setHttpOnly(true);
        // Not marked Secure, so it is also sent over plain ws:// / http:// connections.
        cookie.setSecure(false);

        HttpHeaders httpHeaders = new DefaultHttpHeaders();
        httpHeaders.add(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.LAX.encode(cookie));
        return httpHeaders;
    }

    /**
     * Returns the value of a request cookie whose name equals {@code name} (case-insensitive),
     * or {@code null} if there is no Cookie header or no such cookie.
     */
    private static String findCookieValue(HttpRequest req, String name) {
        String cookieHeader = req.headers().get(HttpHeaderNames.COOKIE);
        if (cookieHeader == null || cookieHeader.isEmpty()) {
            return null;
        }
        Set<Cookie> cookies = ServerCookieDecoder.LAX.decode(cookieHeader);
        for (Cookie cookie : cookies) {
            if (name.equalsIgnoreCase(cookie.name())) {
                return cookie.value();
            }
        }
        return null;
    }
}

然後新增一個繼承自 WebSocketServerProtocolHandler 的類:

/**
 * A {@link WebSocketServerProtocolHandler} that puts
 * {@code CustomWebSocketServerProtocolHandshakeHandler} (not Netty's stock handshake handler)
 * in front of itself when added to a pipeline.
 */
class CustomWebSocketServerProtocolHandler extends WebSocketServerProtocolHandler {

    // Kept here as well so handlerAdded can pass them to the custom handshake handler.
    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;

    /**
     * Convenience constructor: max frame size 65536 and {@code allowMaskMismatch = false}.
     */
    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
            boolean allowExtensions) {
        this(websocketPath, subprotocols, allowExtensions, 65536, false);
    }

    /**
     * Passes all settings to the superclass and keeps a copy for the handshake handler.
     */
    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        super(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch);
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        this.maxFramePayloadLength = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
    }

    /**
     * Inserts, directly before this handler, the custom handshake handler and then a
     * {@link Utf8FrameValidator}, each only if the pipeline does not already contain one.
     *
     * <p>{@code super.handlerAdded} is intentionally not called. Presumably the superclass
     * would install Netty's own handshake handler; verify against the Netty version in use.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        final ChannelPipeline pipeline = ctx.pipeline();

        boolean hasHandshakeHandler =
                pipeline.get(CustomWebSocketServerProtocolHandshakeHandler.class) != null;
        if (!hasHandshakeHandler) {
            CustomWebSocketServerProtocolHandshakeHandler handshakeHandler =
                    new CustomWebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols,
                            allowExtensions, maxFramePayloadLength, allowMaskMismatch);
            pipeline.addBefore(ctx.name(),
                    CustomWebSocketServerProtocolHandshakeHandler.class.getName(), handshakeHandler);
        }

        boolean hasUtf8Validator = pipeline.get(Utf8FrameValidator.class) != null;
        if (!hasUtf8Validator) {
            pipeline.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), new Utf8FrameValidator());
        }
    }
}

最後把它加入管線(pipeline):

pipeline.addLast(new HttpServerCodec()); 
    pipeline.addLast(new HttpObjectAggregator(65536)); 
    pipeline.addLast(new WebSocketServerCompressionHandler()); 
    pipeline.addLast(new CustomWebSocketServerProtocolHandler(WEBSOCKET_PATH, "*", true));