diff --git a/protocol/src/main/java/mc/protocol/ChannelContext.java b/protocol/src/main/java/mc/protocol/ChannelContext.java index ba85b49..51d7a89 100644 --- a/protocol/src/main/java/mc/protocol/ChannelContext.java +++ b/protocol/src/main/java/mc/protocol/ChannelContext.java @@ -3,10 +3,10 @@ package mc.protocol; import io.netty.channel.ChannelHandlerContext; import lombok.Getter; import lombok.RequiredArgsConstructor; -import mc.protocol.packets.Packet; +import mc.protocol.packets.ClientSidePacket; @RequiredArgsConstructor -public class ChannelContext

{ +public class ChannelContext

{ @Getter private final ChannelHandlerContext ctx; diff --git a/protocol/src/main/java/mc/protocol/NettyServer.java b/protocol/src/main/java/mc/protocol/NettyServer.java index b2f20fc..2c92645 100644 --- a/protocol/src/main/java/mc/protocol/NettyServer.java +++ b/protocol/src/main/java/mc/protocol/NettyServer.java @@ -5,19 +5,12 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import mc.protocol.di.DaggerProtocolComponent; import mc.protocol.di.ProtocolComponent; -import mc.protocol.packets.Packet; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Sinks; -import java.util.Map; - -@SuppressWarnings("rawtypes") @Slf4j @RequiredArgsConstructor public class NettyServer { private final ServerBootstrap serverBootstrap; - private final Map, Sinks.Many> observedMap; public void bind(String host, int port) { log.info("Network starting: {}:{}", host, port); @@ -31,11 +24,6 @@ public class NettyServer { } } - @SuppressWarnings("unchecked") - public

Flux> packetFlux(Class

packetClass) { - return observedMap.get(packetClass).asFlux().map(ChannelContext.class::cast); - } - public static NettyServer createServer() { ProtocolComponent component = DaggerProtocolComponent.create(); return component.getNettyServer(); diff --git a/protocol/src/main/java/mc/protocol/PacketInboundHandler.java b/protocol/src/main/java/mc/protocol/PacketInboundHandler.java index 68c3eaa..a99fb4d 100644 --- a/protocol/src/main/java/mc/protocol/PacketInboundHandler.java +++ b/protocol/src/main/java/mc/protocol/PacketInboundHandler.java @@ -3,21 +3,20 @@ package mc.protocol; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import lombok.RequiredArgsConstructor; -import mc.protocol.packets.Packet; +import mc.protocol.packets.ClientSidePacket; import reactor.core.publisher.Sinks; -import java.util.Map; - -@SuppressWarnings("rawtypes") @RequiredArgsConstructor -public class PacketInboundHandler extends SimpleChannelInboundHandler { - - private final Map, Sinks.Many> observedMap; +public class PacketInboundHandler extends SimpleChannelInboundHandler { + @SuppressWarnings("rawtypes") @Override - protected void channelRead0(ChannelHandlerContext ctx, Packet packet) { - if (observedMap.containsKey(packet.getClass())) { - observedMap.get(packet.getClass()).tryEmitNext(new ChannelContext<>(ctx, packet)); + protected void channelRead0(ChannelHandlerContext ctx, ClientSidePacket packet) { + Sinks.Many packetSinks = ctx.channel().attr(NetworkAttributes.STATE) + .get().getPacketSinks(packet.getClass()); + + if (packetSinks != null) { + packetSinks.tryEmitNext(new ChannelContext<>(ctx, packet)); } } } diff --git a/protocol/src/main/java/mc/protocol/State.java b/protocol/src/main/java/mc/protocol/State.java index 347c30c..f284800 100644 --- a/protocol/src/main/java/mc/protocol/State.java +++ b/protocol/src/main/java/mc/protocol/State.java @@ -8,9 +8,12 @@ import mc.protocol.packets.PingPacket; import mc.protocol.packets.ServerSidePacket; import mc.protocol.packets.client.*; import mc.protocol.packets.server.*; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; import javax.annotation.Nullable; import java.util.Collections; +import java.util.HashMap; import java.util.Map; @RequiredArgsConstructor @@ -75,10 +78,12 @@ public enum State { @Getter private final int id; - @Getter private final Map> clientSidePackets; private final Map, Integer> serverSidePackets; + @SuppressWarnings("rawtypes") + private final Map, Sinks.Many> observedMap = new HashMap<>(); + State(int id, Map> clientSidePackets) { this.id = id; this.clientSidePackets = clientSidePackets; @@ -94,4 +99,16 @@ public enum State { public Integer getServerSidePacketId(Class clazz) { return serverSidePackets == null ? null : serverSidePackets.get(clazz); } + + + @SuppressWarnings("rawtypes") + public

Sinks.Many getPacketSinks(Class

packetClass) { + return observedMap.get(packetClass); + } + + @SuppressWarnings("unchecked") + public

Flux> packetFlux(Class

packetClass) { + return observedMap.computeIfAbsent(packetClass, aClass -> Sinks.many().multicast().directBestEffort()) + .asFlux().map(ChannelContext.class::cast); + } } diff --git a/protocol/src/main/java/mc/protocol/di/ProtocolModule.java b/protocol/src/main/java/mc/protocol/di/ProtocolModule.java index 167d27f..3e8a8df 100644 --- a/protocol/src/main/java/mc/protocol/di/ProtocolModule.java +++ b/protocol/src/main/java/mc/protocol/di/ProtocolModule.java @@ -11,31 +11,23 @@ import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; -import mc.protocol.ChannelContext; import mc.protocol.NettyServer; import mc.protocol.PacketInboundHandler; -import mc.protocol.State; import mc.protocol.io.codec.ProtocolDecoder; import mc.protocol.io.codec.ProtocolEncoder; import mc.protocol.io.codec.ProtocolSplitter; -import mc.protocol.packets.Packet; -import reactor.core.publisher.Sinks; import javax.annotation.Nonnull; import javax.inject.Provider; import java.util.LinkedHashMap; import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; @Module public class ProtocolModule { - @SuppressWarnings("rawtypes") @Provides - NettyServer provideServer(ServerBootstrap serverBootstrap, - Map, Sinks.Many> observedMap) { - return new NettyServer(serverBootstrap, observedMap); + NettyServer provideServer(ServerBootstrap serverBootstrap) { + return new NettyServer(serverBootstrap); } @Provides @@ -62,28 +54,16 @@ public class ProtocolModule { }; } - @SuppressWarnings("rawtypes") @Provides - Map provideChannelHandlerMap( - Map, Sinks.Many> observedMap) { - + Map provideChannelHandlerMap() { Map map = new LinkedHashMap<>(); map.put("packet_splitter", new ProtocolSplitter()); map.put("logger", new LoggingHandler(LogLevel.DEBUG)); map.put("packet_decoder", new ProtocolDecoder(true)); map.put("packet_encoder", new ProtocolEncoder()); - map.put("packet_handler", new PacketInboundHandler(observedMap)); + map.put("packet_handler", new PacketInboundHandler()); return map; } - - @SuppressWarnings("rawtypes") - @Provides - @ServerScope - Map, Sinks.Many> provideObservedMap() { - return Stream.of(State.values()) - .flatMap(state -> state.getClientSidePackets().values().stream()) - .collect(Collectors.toMap(packetClass -> packetClass, v -> Sinks.many().multicast().directBestEffort())); - } } diff --git a/server/src/main/java/mc/server/Main.java b/server/src/main/java/mc/server/Main.java index b3120e2..e8bb6aa 100644 --- a/server/src/main/java/mc/server/Main.java +++ b/server/src/main/java/mc/server/Main.java @@ -8,6 +8,7 @@ import joptsimple.OptionSet; import joptsimple.util.PathConverter; import lombok.extern.slf4j.Slf4j; import mc.protocol.NettyServer; +import mc.protocol.State; import mc.protocol.packets.PingPacket; import mc.protocol.packets.client.HandshakePacket; import mc.protocol.packets.client.LoginStartPacket; @@ -48,10 +49,10 @@ public class Main { NettyServer server = NettyServer.createServer(); PacketHandler packetHandler = serverComponent.getPacketHandler(); - server.packetFlux(HandshakePacket.class).subscribe(packetHandler::onHandshake); - server.packetFlux(PingPacket.class).subscribe(packetHandler::onKeepAlive); - server.packetFlux(StatusServerRequestPacket.class).subscribe(packetHandler::onServerStatus); - server.packetFlux(LoginStartPacket.class).subscribe(packetHandler::onLoginStart); + State.HANDSHAKING.packetFlux(HandshakePacket.class).subscribe(packetHandler::onHandshake); + State.STATUS.packetFlux(PingPacket.class).subscribe(packetHandler::onKeepAlive); + State.STATUS.packetFlux(StatusServerRequestPacket.class).subscribe(packetHandler::onServerStatus); + State.LOGIN.packetFlux(LoginStartPacket.class).subscribe(packetHandler::onLoginStart); server.bind(config.server().host(), config.server().port()); }