Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ protected ArrowBuf doCompress(BufferAllocator allocator, ArrowBuf uncompressedBu
long bytesWritten =
Zstd.compressUnsafe(
compressedBuffer.memoryAddress() + CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH,
dstSize,
maxSize,
/*src*/ uncompressedBuffer.memoryAddress(),
/* srcSize= */ uncompressedBuffer.writerIndex(),
/* level= */ this.compressionLevel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
Expand All @@ -53,12 +55,15 @@
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
Expand Down Expand Up @@ -347,6 +352,253 @@ void testUnloadCompressed(CompressionUtil.CodecType codec) {
});
}

/**
 * Test multi-batch streaming with ZSTD compression, wide schema, VectorSchemaRoot reuse, and
 * all-null columns. This reproduces the scenario from GH-1116 where the 8-byte
 * uncompressed-length prefix of a compressed buffer could be incorrectly written as 0.
 *
 * <p>The writer and reader loops below intentionally use the same modulo patterns
 * (col % 3 for type, row % 7/5/9 for per-row nulls, batch % 3 for all-null columns)
 * so the read-back verification mirrors exactly what was written.
 */
@Test
void testMultiBatchZstdStreamWithWideSchemaAndAllNulls() throws Exception {
final int fieldCount = 100;
final int batchCount = 10;
final int rowsPerBatch = 500;

// Build a wide schema: mix of int, timestamp, and varchar fields
// (round-robin by column index: 0 -> Int32, 1 -> TimestampMilli, 2 -> Utf8)
List<Field> fields = new ArrayList<>();
for (int i = 0; i < fieldCount; i++) {
switch (i % 3) {
case 0:
fields.add(Field.nullable("int_" + i, new ArrowType.Int(32, true)));
break;
case 1:
fields.add(
Field.nullable("ts_" + i, new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)));
break;
case 2:
fields.add(Field.nullable("str_" + i, ArrowType.Utf8.INSTANCE));
break;
default:
break;
}
}
Schema schema = new Schema(fields);

// Write batchCount ZSTD-compressed record batches into an in-memory stream,
// reusing a single VectorSchemaRoot across batches.
ByteArrayOutputStream out = new ByteArrayOutputStream();
try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
ArrowStreamWriter writer =
new ArrowStreamWriter(
root,
new DictionaryProvider.MapDictionaryProvider(),
Channels.newChannel(out),
IpcOption.DEFAULT,
CommonsCompressionFactory.INSTANCE,
CompressionUtil.CodecType.ZSTD)) {
writer.start();

for (int batch = 0; batch < batchCount; batch++) {
// Clear and reallocate — mimics the reporter's reuse pattern
root.clear();
for (FieldVector vector : root.getFieldVectors()) {
vector.allocateNew();
}
root.setRowCount(rowsPerBatch);

for (int col = 0; col < fieldCount; col++) {
FieldVector vector = root.getVector(col);
// Make some batches have all-null columns for certain fields
boolean allNull = (batch % 3 == 0) && (col % 3 == 1); // timestamps in every 3rd batch
switch (col % 3) {
case 0:
{
IntVector iv = (IntVector) vector;
for (int row = 0; row < rowsPerBatch; row++) {
if (allNull || row % 7 == 0) {
iv.setNull(row);
} else {
// Value encodes (batch, row) so read-back can recompute the expected value.
iv.setSafe(row, batch * rowsPerBatch + row);
}
}
break;
}
case 1:
{
TimeStampMilliVector tv = (TimeStampMilliVector) vector;
for (int row = 0; row < rowsPerBatch; row++) {
if (allNull || row % 5 == 0) {
tv.setNull(row);
} else {
tv.setSafe(row, 1_700_000_000_000L + (long) batch * rowsPerBatch + row);
}
}
break;
}
case 2:
{
VarCharVector sv = (VarCharVector) vector;
for (int row = 0; row < rowsPerBatch; row++) {
if (allNull || row % 9 == 0) {
sv.setNull(row);
} else {
sv.setSafe(row, ("val_" + batch + "_" + row).getBytes(StandardCharsets.UTF_8));
}
}
break;
}
default:
break;
}
vector.setValueCount(rowsPerBatch);
}

writer.writeBatch();
}
writer.end();
}

// Read back and verify all batches round-trip correctly
try (ArrowStreamReader reader =
new ArrowStreamReader(
new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
allocator,
CommonsCompressionFactory.INSTANCE)) {
int batchesRead = 0;
while (reader.loadNextBatch()) {
VectorSchemaRoot readRoot = reader.getVectorSchemaRoot();
assertEquals(rowsPerBatch, readRoot.getRowCount());
assertEquals(fieldCount, readRoot.getFieldVectors().size());

// Verify data values, null patterns, and all-null columns.
// The expectations below must mirror the writer's modulo patterns exactly.
for (int col = 0; col < fieldCount; col++) {
FieldVector vector = readRoot.getVector(col);
boolean allNull =
(batchesRead % 3 == 0) && (col % 3 == 1); // timestamps in every 3rd batch
if (allNull) {
// The key scenario: all-null columns must survive compression round-trip
assertEquals(
rowsPerBatch,
vector.getNullCount(),
"All-null column col=" + col + " batch=" + batchesRead);
}
for (int row = 0; row < rowsPerBatch; row++) {
switch (col % 3) {
case 0:
{
IntVector iv = (IntVector) vector;
if (allNull || row % 7 == 0) {
assertTrue(
iv.isNull(row),
"Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
} else {
assertEquals(
batchesRead * rowsPerBatch + row,
iv.get(row),
"Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
}
break;
}
case 1:
{
TimeStampMilliVector tv = (TimeStampMilliVector) vector;
if (allNull || row % 5 == 0) {
assertTrue(
tv.isNull(row),
"Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
} else {
assertEquals(
1_700_000_000_000L + (long) batchesRead * rowsPerBatch + row,
tv.get(row),
"Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
}
break;
}
case 2:
{
VarCharVector sv = (VarCharVector) vector;
if (allNull || row % 9 == 0) {
assertTrue(
sv.isNull(row),
"Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
} else {
assertArrayEquals(
("val_" + batchesRead + "_" + row).getBytes(StandardCharsets.UTF_8),
sv.get(row),
"Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
}
break;
}
default:
break;
}
}
}
batchesRead++;
}
// No batch may be silently dropped.
assertEquals(batchCount, batchesRead);
}
}

/**
 * Test that an all-null fixed-width vector compresses and decompresses correctly. The data buffer
 * for such a vector contains all zeros but has a non-zero writerIndex (valueCount * typeWidth).
 * The compressed buffer's uncompressed-length prefix must reflect this non-zero size.
 */
@Test
void testAllNullFixedWidthVectorZstdRoundTrip() throws Exception {
final int rowCount = 3469; // same count as the reported issue
final CompressionCodec codec = new ZstdCompressionCodec();

try (TimeStampMilliVector origVec =
new TimeStampMilliVector(
"ts",
FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
allocator)) {
origVec.allocateNew(rowCount);
// Set all values to null
for (int i = 0; i < rowCount; i++) {
origVec.setNull(i);
}
origVec.setValueCount(rowCount);

assertEquals(rowCount, origVec.getNullCount());

// Compress and decompress each buffer.
// Field buffers for a fixed-width vector: index 0 = validity, index 1 = data.
List<ArrowBuf> origBuffers = origVec.getFieldBuffers();
assertEquals(2, origBuffers.size());

// The data buffer (index 1) should have non-zero writerIndex even though all values are null
ArrowBuf dataBuffer = origBuffers.get(1);
long expectedDataSize = (long) rowCount * 8; // TimestampMilli = 8 bytes per value
assertEquals(expectedDataSize, dataBuffer.writerIndex());

// Retain buffers before compressing since compress() closes the input buffer.
// This mirrors what VectorUnloader.appendNodes() does.
for (ArrowBuf buf : origBuffers) {
buf.getReferenceManager().retain();
}
List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers);
List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers);

assertEquals(2, decompressedBuffers.size());

// The decompressed data buffer should have the same writerIndex as the original
assertEquals(expectedDataSize, decompressedBuffers.get(1).writerIndex());

// Load into a new vector and verify
try (TimeStampMilliVector newVec =
new TimeStampMilliVector(
"ts_new",
FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
allocator)) {
newVec.loadFieldBuffers(new ArrowFieldNode(rowCount, rowCount), decompressedBuffers);
assertEquals(rowCount, newVec.getValueCount());
for (int i = 0; i < rowCount; i++) {
assertTrue(newVec.isNull(i));
}
}
// Release the decompressed buffers we own; the vectors' try-with-resources
// handles the rest.
AutoCloseables.close(decompressedBuffers);
}
}

void withRoot(
CompressionUtil.CodecType codec,
BiConsumer<CompressionCodec.Factory, VectorSchemaRoot> testBody) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ public abstract class AbstractCompressionCodec implements CompressionCodec {

@Override
public ArrowBuf compress(BufferAllocator allocator, ArrowBuf uncompressedBuffer) {
if (uncompressedBuffer.writerIndex() == 0L) {
// Capture the uncompressed length once upfront to avoid any inconsistency from
// re-reading writerIndex() at different points. Since the uncompressedBuffer may be
// a shared reference to a vector's internal buffer, reading writerIndex() only once
// ensures the same value is used for the empty-buffer check, compression, size
// comparison, and the 8-byte uncompressed-length prefix.
long uncompressedLength = uncompressedBuffer.writerIndex();

if (uncompressedLength == 0L) {
// shortcut for empty buffer
ArrowBuf compressedBuffer = allocator.buffer(CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH);
compressedBuffer.setLong(0, 0);
Expand All @@ -41,7 +48,6 @@ public ArrowBuf compress(BufferAllocator allocator, ArrowBuf uncompressedBuffer)
ArrowBuf compressedBuffer = doCompress(allocator, uncompressedBuffer);
long compressedLength =
compressedBuffer.writerIndex() - CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH;
long uncompressedLength = uncompressedBuffer.writerIndex();

if (compressedLength > uncompressedLength) {
// compressed buffer is larger, send the raw buffer
Expand Down
Loading