diff --git a/mqtt-broker/src/main/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/PulsarMessageConverter.java b/mqtt-broker/src/main/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/PulsarMessageConverter.java index 565c2bde9..cf86e9ab0 100644 --- a/mqtt-broker/src/main/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/PulsarMessageConverter.java +++ b/mqtt-broker/src/main/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/PulsarMessageConverter.java @@ -23,6 +23,7 @@ import io.netty.handler.codec.mqtt.MqttProperties; import io.netty.handler.codec.mqtt.MqttPublishMessage; import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.FastThreadLocal; import io.streamnative.pulsar.handlers.mqtt.broker.MQTTServerConfiguration; import io.streamnative.pulsar.handlers.mqtt.common.mqtt5.PacketIdGenerator; @@ -203,12 +204,13 @@ public static List toMqttMessages(String topicName, Entry en } catch (IOException e) { log.error("Error decoding batch for message {}. Whole batch will be included in output", entry.getPosition(), e); + response.forEach(ReferenceCountUtil::safeRelease); return Collections.emptyList(); } } else { return Lists.newArrayList(MessageBuilder.publish() .messageId(packetIdGenerator.nextPacketId()) - .payload(metadataAndPayload) + .payload(metadataAndPayload.retainedSlice()) .topicName(topicName) .qos(qos) .properties(properties) diff --git a/mqtt-broker/src/main/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/consumer/MQTTConsumer.java b/mqtt-broker/src/main/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/consumer/MQTTConsumer.java index 787674622..a7643d89d 100644 --- a/mqtt-broker/src/main/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/consumer/MQTTConsumer.java +++ b/mqtt-broker/src/main/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/consumer/MQTTConsumer.java @@ -18,6 +18,7 @@ import io.netty.channel.ChannelPromise; import io.netty.handler.codec.mqtt.MqttPublishMessage; import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.Future; import io.streamnative.pulsar.handlers.mqtt.broker.channel.MQTTServerCnx; import io.streamnative.pulsar.handlers.mqtt.broker.impl.PulsarMessageConverter; @@ -27,6 +28,7 @@ import io.streamnative.pulsar.handlers.mqtt.common.mqtt5.PacketIdGenerator; import io.streamnative.pulsar.handlers.mqtt.common.mqtt5.restrictions.ClientRestrictions; import io.streamnative.pulsar.handlers.mqtt.common.utils.PulsarTopicUtils; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -97,78 +99,117 @@ public Future sendMessages(List entries, EntryBatchSizes long sentMessages = 0; long sentBytes = 0; MESSAGE_PERMITS_UPDATER.addAndGet(this, -totalMessages); - for (Entry entry : entries) { - String toConsumerTopicName = PulsarTopicUtils.getToConsumerTopicName(mqttTopicName, pulsarTopicName); - List messages = PulsarMessageConverter.toMqttMessages(toConsumerTopicName, entry, - packetIdGenerator, qos); - if (MqttQoS.AT_MOST_ONCE != qos) { - final boolean isBatch = messages.size() > 1; - if (isBatch) { - for (int i = 0; i < messages.size(); i++) { - int packetId = messages.get(i).variableHeader().packetId(); - OutstandingPacket outstandingPacket = new OutstandingPacket(this, packetId, entry.getLedgerId(), - entry.getEntryId(), i, messages.size()); - outstandingPacketContainer.add(outstandingPacket); + List entriesToRelease = new ArrayList<>(entries.size()); + try { + for (Entry entry : entries) { + if (entry == null) { + continue; + } + entriesToRelease.add(entry); + String toConsumerTopicName = PulsarTopicUtils.getToConsumerTopicName(mqttTopicName, pulsarTopicName); + List messages = PulsarMessageConverter.toMqttMessages(toConsumerTopicName, entry, + packetIdGenerator, qos); + if (messages.isEmpty()) { + continue; + } + if (MqttQoS.AT_MOST_ONCE != qos) { + final boolean isBatch = messages.size() > 1; + if (isBatch) { + for (int i = 0; i < messages.size(); i++) { + int packetId = messages.get(i).variableHeader().packetId(); + OutstandingPacket outstandingPacket = new OutstandingPacket(this, packetId, + entry.getLedgerId(), entry.getEntryId(), i, messages.size()); + outstandingPacketContainer.add(outstandingPacket); + } + } else { + // Because batch msg is sent from Pulsar client, + // so only individual msg may have mqtt-5 properties. + MqttPublishMessage firstMessage = messages.get(0); + long expiryInterval = getMessageExpiryInterval(firstMessage); + boolean addToOutstandingPacketContainer = expiryInterval >= 0; + if (expiryInterval < 0) { + log.warn("mqtt msg has expired : {}", firstMessage); + ReferenceCountUtil.safeRelease(messages.remove(0)); + getSubscription().acknowledgeMessage( + Collections.singletonList(entry.getPosition()), + CommandAck.AckType.Individual, Collections.emptyMap()); + } + if (addToOutstandingPacketContainer) { + OutstandingPacket outstandingPacket = new OutstandingPacket(this, + messages.get(0).variableHeader().packetId(), entry.getLedgerId(), + entry.getEntryId()); + outstandingPacketContainer.add(outstandingPacket); + } } - } else { - // Because batch msg is sent from Pulsar client, so only individual msg may have mqtt-5 properties. - MqttPublishMessage firstMessage = messages.get(0); - long expiryInterval = getMessageExpiryInterval(firstMessage); - boolean addToOutstandingPacketContainer = expiryInterval >= 0; - if (expiryInterval < 0) { - log.warn("mqtt msg has expired : {}", firstMessage); - messages.remove(0); - getSubscription().acknowledgeMessage( - Collections.singletonList(entry.getPosition()), + } + for (MqttPublishMessage msg : messages) { + if (log.isDebugEnabled()) { + log.debug("[{}] [{}] [{}] Send MQTT message {} to subscriber", pulsarTopicName, + mqttTopicName, super.getSubscription().getName(), msg); + } + final int readableBytes = msg.payload().readableBytes(); + metricsCollector.addReceived(readableBytes); + if (clientRestrictions.exceedMaximumPacketSize(readableBytes)) { + log.warn("discard msg {}, because it exceeds maximum packet size : {}, msg size {}", msg, + clientRestrictions.getMaximumPacketSize(), readableBytes); + getSubscription().acknowledgeMessage(Collections.singletonList(entry.getPosition()), CommandAck.AckType.Individual, Collections.emptyMap()); + ReferenceCountUtil.safeRelease(msg); + continue; } - if (addToOutstandingPacketContainer) { - OutstandingPacket outstandingPacket = new OutstandingPacket(this, - messages.get(0).variableHeader().packetId(), entry.getLedgerId(), entry.getEntryId()); - outstandingPacketContainer.add(outstandingPacket); + sentMessages++; + sentBytes += readableBytes; + boolean written = false; + try { + cnx.ctx().channel().write(new MqttAdapterMessage(connection.getClientId(), msg, + connection.isFromProxy())); + written = true; + } finally { + if (!written) { + ReferenceCountUtil.safeRelease(msg); + } } } } - for (MqttPublishMessage msg : messages) { - if (log.isDebugEnabled()) { - log.debug("[{}] [{}] [{}] Send MQTT message {} to subscriber", pulsarTopicName, - mqttTopicName, super.getSubscription().getName(), msg); - } - final int readableBytes = msg.payload().readableBytes(); - metricsCollector.addReceived(readableBytes); - if (clientRestrictions.exceedMaximumPacketSize(readableBytes)) { - log.warn("discard msg {}, because it exceeds maximum packet size : {}, msg size {}", msg, - clientRestrictions.getMaximumPacketSize(), readableBytes); - getSubscription().acknowledgeMessage(Collections.singletonList(entry.getPosition()), - CommandAck.AckType.Individual, Collections.emptyMap()); - continue; + if (MqttQoS.AT_MOST_ONCE == qos) { + incrementPermits(totalMessages); + if (entries.size() > 0) { + getSubscription().acknowledgeMessage( + Collections.singletonList(entries.get(entries.size() - 1).getPosition()), + CommandAck.AckType.Cumulative, Collections.emptyMap()); } - sentMessages++; - sentBytes += readableBytes; - cnx.ctx().channel().write(new MqttAdapterMessage(connection.getClientId(), msg, - connection.isFromProxy())); - } - } - if (MqttQoS.AT_MOST_ONCE == qos) { - incrementPermits(totalMessages); - if (entries.size() > 0) { - getSubscription().acknowledgeMessage( - Collections.singletonList(entries.get(entries.size() - 1).getPosition()), - CommandAck.AckType.Cumulative, Collections.emptyMap()); } + final long deliveredMessages = sentMessages; + final long deliveredBytes = sentBytes; + cnx.ctx().channel().writeAndFlush(Unpooled.EMPTY_BUFFER, promise); + promise.addListener(future -> { + try { + if (future.isSuccess() && (deliveredMessages > 0 || deliveredBytes > 0)) { + recordStatsUpdate(System.currentTimeMillis(), 0, System.currentTimeMillis(), + deliveredMessages, deliveredBytes); + } + } finally { + entriesToRelease.forEach(Entry::release); + recycleBatchObjects(batchSizes, batchIndexesAcks); + } + }); + } catch (Throwable t) { + entriesToRelease.forEach(Entry::release); + recycleBatchObjects(batchSizes, batchIndexesAcks); + promise.tryFailure(t); } - final long deliveredMessages = sentMessages; - final long deliveredBytes = sentBytes; - cnx.ctx().channel().writeAndFlush(Unpooled.EMPTY_BUFFER, promise); - promise.addListener(future -> { - if (future.isSuccess() && (deliveredMessages > 0 || deliveredBytes > 0)) { - recordStatsUpdate(System.currentTimeMillis(), 0, System.currentTimeMillis(), - deliveredMessages, deliveredBytes); - } - }); return promise; } + private static void recycleBatchObjects(EntryBatchSizes batchSizes, EntryBatchIndexesAcks batchIndexesAcks) { + if (batchSizes != null) { + batchSizes.recyle(); + } + if (batchIndexesAcks != null) { + batchIndexesAcks.recycle(); + } + } + @Override public boolean equals(Object o) { return super.equals(o); diff --git a/mqtt-broker/src/test/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/PulsarMessageConverterTest.java b/mqtt-broker/src/test/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/PulsarMessageConverterTest.java new file mode 100644 index 000000000..83babfbda --- /dev/null +++ b/mqtt-broker/src/test/java/io/streamnative/pulsar/handlers/mqtt/broker/impl/PulsarMessageConverterTest.java @@ -0,0 +1,58 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.pulsar.handlers.mqtt.broker.impl; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.mqtt.MqttPublishMessage; +import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.util.ReferenceCountUtil; +import io.streamnative.pulsar.handlers.mqtt.common.mqtt5.PacketIdGenerator; +import java.util.List; +import org.apache.bookkeeper.mledger.PositionFactory; +import org.apache.bookkeeper.mledger.impl.EntryImpl; +import org.apache.pulsar.common.api.proto.MessageMetadata; +import org.apache.pulsar.common.protocol.Commands; +import org.testng.annotations.Test; + +public class PulsarMessageConverterTest { + + @Test + public void testMqttPayloadReleaseDoesNotReleaseEntry() { + ByteBuf payload = Unpooled.copiedBuffer("payload", UTF_8); + MessageMetadata metadata = new MessageMetadata() + .setProducerName("producer") + .setSequenceId(1) + .setPublishTime(System.currentTimeMillis()); + ByteBuf metadataAndPayload = Commands.serializeMetadataAndPayload(Commands.ChecksumType.None, + metadata, payload); + EntryImpl entry = EntryImpl.create(PositionFactory.create(1, 1), metadataAndPayload, 1); + metadataAndPayload.release(); + payload.release(); + + List messages = PulsarMessageConverter.toMqttMessages("mqtt/topic", entry, + PacketIdGenerator.newNonZeroGenerator(), MqttQoS.AT_MOST_ONCE); + + assertEquals(messages.size(), 1); + MqttPublishMessage message = messages.get(0); + assertEquals(message.payload().toString(UTF_8), "payload"); + + ReferenceCountUtil.release(message); + assertEquals(entry.refCnt(), 1); + + entry.release(); + } +}