The One Billion Row Challenge in Java

Background

GitHub project

The One Billion Row Challenge -- A fun exploration of how quickly 1B rows from a text file can be aggregated with Java

1B rows is about 13 GB of data.

Environment

  • Hardware: 32 cores, 128 GB memory
  • JDK21
  • Linux

Implementation

Iteration 1

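  • Use MappedByteBuffer to avoid copying buffers between kernel space and user space
  • Split the file into chunks according to Runtime.getRuntime().availableProcessors()
  • Aggregate the temperatures in a HashMap: key = the station name's byte[], value = a MeasurementAggregator holding count, min, sum, max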
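A minimal sketch of the iteration-1 chunking step, under these assumptions: every chunk stays below 2 GB (a single MappedByteBuffer is indexed by int), lines end with '\n', and each chunk boundary is pushed forward past the next newline so that no line is split between two chunks. The ChunkSplitter class and its split/alignToNewline methods are hypothetical names, not taken from the original code.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;

public class ChunkSplitter {

    // Maps the file as roughly equal chunks, moving each boundary forward to
    // just past the next '\n' so that every chunk contains only whole lines.
    // Assumes every chunk stays below 2 GB, the limit of one MappedByteBuffer.
    static List<MappedByteBuffer> split(Path file, int chunkCount) throws IOException {
        List<MappedByteBuffer> chunks = new ArrayList<>();
        try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) {
            long size = channel.size();
            long chunkSize = Math.max(1, size / chunkCount);
            long start = 0;
            while (start < size) {
                long end = alignToNewline(channel, Math.min(size, start + chunkSize), size);
                chunks.add(channel.map(FileChannel.MapMode.READ_ONLY, start, end - start));
                start = end;
            }
        } // mappings stay valid after the channel is closed
        return chunks;
    }

    // Returns the position just after the first '\n' at or after pos, or size if none.
    private static long alignToNewline(FileChannel channel, long pos, long size) throws IOException {
        ByteBuffer window = ByteBuffer.allocate(128);
        while (pos < size) {
            window.clear();
            int n = channel.read(window, pos);
            if (n <= 0) {
                break;
            }
            for (int i = 0; i < n; i++) {
                if (window.get(i) == '\n') {
                    return pos + i + 1;
                }
            }
            pos += n;
        }
        return size;
    }
}
```

A caller would pass Runtime.getRuntime().availableProcessors() as chunkCount and hand each chunk to its own thread.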

The results were disappointing.

Profiling with a flame graph showed:

  • HashMap::containsKey(key) is very expensive
  • key::hashCode and key::equals are slow

Iteration 2

Optimizations

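static class CalculateKey {
    // Nested class; the enclosing file imports java.util.Arrays and java.util.Objects
    private final byte[] bytes; // station name bytes
    private final int length;   // number of valid bytes in `bytes`
    private final int hash;     // hash computed while the bytes were read

    public CalculateKey(byte[] bytes, int length, int hash) {
        this.bytes = bytes;
        this.length = length;
        this.hash = hash;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        CalculateKey that = (CalculateKey) o;
        // Cheap checks first, then compare only the first `length` bytes
        return length == that.length && hash == that.hash
                && Arrays.equals(bytes, 0, length, that.bytes, 0, that.length);
    }

    @Override
    public int hashCode() {
        // Slow: Objects.hash allocates a varargs array and boxes its arguments,
        // and Arrays.hashCode walks every byte on each call
        int result = Objects.hash(length, hash);
        result = 31 * result + Arrays.hashCode(bytes);
        return result;
    }
}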
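  • Remove the HashMap::containsKey call, so each row needs only one map lookup

  • Optimize hashCode: compute the key's hash incrementally while its bytes are being read, so it is not recomputed when the key is put into the map

    • hashCode = 31 * hashCode + positionByte;
@Override
public int hashCode() {
    return hash;
}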
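A sketch of the two iteration-2 changes together, assuming temperatures are kept as integer tenths of a degree: the key caches the hash that was built while its bytes were read, and a single HashMap.compute call replaces the containsKey check plus get/put. StationKey, Stats (a stand-in for MeasurementAggregator), accept and parseTenths are hypothetical names.

```java
import java.util.Arrays;
import java.util.Map;

public class CachedHashDemo {

    // Key that stores the hash computed while its bytes were scanned.
    static final class StationKey {
        final byte[] bytes;
        final int hash;

        StationKey(byte[] bytes, int hash) {
            this.bytes = bytes;
            this.hash = hash;
        }

        @Override
        public int hashCode() {
            return hash; // no work on put/get
        }

        @Override
        public boolean equals(Object o) {
            // The cached hash is a cheap first filter before comparing bytes.
            return o instanceof StationKey k && k.hash == hash && Arrays.equals(k.bytes, bytes);
        }
    }

    // Stand-in for MeasurementAggregator; values are in tenths of a degree.
    static final class Stats {
        long count;
        long sum;
        long min = Long.MAX_VALUE;
        long max = Long.MIN_VALUE;
    }

    // Handles one "station;value" line (without the trailing '\n').
    static void accept(Map<StationKey, Stats> map, byte[] line) {
        int hash = 0;
        int i = 0;
        while (line[i] != ';') {
            hash = 31 * hash + line[i]; // hash built while reading, as in the text
            i++;
        }
        StationKey key = new StationKey(Arrays.copyOf(line, i), hash);
        long tenths = parseTenths(line, i + 1);
        // A single compute call replaces containsKey + get + put.
        map.compute(key, (k, s) -> {
            if (s == null) {
                s = new Stats();
            }
            s.count++;
            s.sum += tenths;
            s.min = Math.min(s.min, tenths);
            s.max = Math.max(s.max, tenths);
            return s;
        });
    }

    // Parses a value with exactly one fractional digit, e.g. "-3.4" -> -34.
    static long parseTenths(byte[] b, int from) {
        boolean negative = b[from] == '-';
        long v = 0;
        for (int j = negative ? from + 1 : from; j < b.length; j++) {
            if (b[j] != '.') {
                v = v * 10 + (b[j] - '0');
            }
        }
        return negative ? -v : v;
    }
}
```

For ASCII names this hash equals String.hashCode of the name, since both start from 0 and apply 31 * h + c per character.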

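This helped: 100 million rows now take a little over 2 s on my local machine (8 cores, 16 GB).

Profiler analysis shows the top hotspots:

  • MappedByteBuffer.get()
  • HashMap.get
  • new byte[] (allocating the byte array for each key)

These three account for most of the overhead, so they are the next targets.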

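Iteration 3

Replace HashMap with a custom implementation

Memory layout:

  • ByteBuffer data; // the original data (the mapped chunk)

  • int[] index; // offsets of the station records in mem

  • long[] mem; // one fixed-size record per station

    • mem[off] = byteOffset, mem[off+1] = keyLength, mem[off+2] = measurement, mem[off+3] = key's hash code

      • byteOffset is the start position of the key in data; together with keyLength it locates the key's bytes
      • the key's hash code together with keyLength serves as a fast equality check; since two different names can share a hash and a length, the bytes must also be compared to be strictly correct
    • mem[off+4] ~ mem[off+7] = count, sum, min, max

    • the next station's record starts 8 longs later (index uses 0 to mean "empty bucket", so the first record must start at offset 1, which is why mem gets one extra long)

    • each station takes 8 longs

    • this trades space for time

    • data is only a reference: no memory is allocated for the keys, and a key is found through its offset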

static class SimpleHashMap {
    // Nested in the main class, which imports java.nio.ByteBuffer,
    // java.nio.charset.StandardCharsets and java.util.Map
    ByteBuffer data; // original data (the chunk), referenced, not copied
    private static final int STATION_SIZE = 8; // longs per station record
    private static final int CAPACITY = 1024 * 64; // buckets; a power of two, far above the station count
    private static final int INDEX_MASK = CAPACITY - 1;
    // index[bucket] -> offset of the record in mem[]; 0 means "empty", so records start at offset 1
    int[] index;
    // mem[off] = byteOffset, mem[off+1] = keyLength, mem[off+2] = measurement, mem[off+3] = key's hashcode
    // mem[off+4] ~ mem[off+7] = count, sum, min, max
    long[] mem;

    public SimpleHashMap(ByteBuffer chunk) {
        index = new int[CAPACITY];
        mem = new long[CAPACITY * STATION_SIZE + 1];
        data = chunk;
    }

    // Returns true if a new record was written at memOffset, so the caller
    // knows to advance memOffset by STATION_SIZE for the next new station
    public boolean put(int byteOffset, int hash, long value, int keyLength, int memOffset) {
        for (int bucket = hash & INDEX_MASK;; bucket = (bucket + 1) & INDEX_MASK) { // linear probing
            int offset = index[bucket];
            if (offset == 0) {
                index[bucket] = memOffset;
                mem[memOffset] = byteOffset;
                mem[memOffset + 1] = keyLength;
                mem[memOffset + 2] = value;
                mem[memOffset + 3] = hash;
                mem[memOffset + 4] = 1; // count
                mem[memOffset + 5] = value; // sum
                mem[memOffset + 6] = value; // min
                mem[memOffset + 7] = value; // max
                return true;
            }
            if (keyEquals(offset, byteOffset, hash, keyLength)) {
                mem[offset + 4] += 1; // count
                mem[offset + 5] += value; // sum
                mem[offset + 6] = Math.min(value, mem[offset + 6]); // min
                mem[offset + 7] = Math.max(value, mem[offset + 7]); // max
                return false;
            }
        }
    }

    // Returns the record offset for the key, or 0 if it is absent
    public int get(int byteOffset, int hash, int keyLength) {
        for (int bucket = hash & INDEX_MASK;; bucket = (bucket + 1) & INDEX_MASK) {
            int offset = index[bucket];
            if (offset == 0 || keyEquals(offset, byteOffset, hash, keyLength)) {
                return offset;
            }
        }
    }

    // Hash and length are cheap filters; comparing the bytes rules out hash collisions
    private boolean keyEquals(int offset, int byteOffset, int hash, int keyLength) {
        if ((int) mem[offset + 3] != hash || (int) mem[offset + 1] != keyLength) {
            return false;
        }
        int prevStart = (int) mem[offset];
        for (int i = 0; i < keyLength; i++) {
            if (data.get(prevStart + i) != data.get(byteOffset + i)) {
                return false;
            }
        }
        return true;
    }

    void merge(Map<String, MeasurementAggregator> target) {
        // The absolute get(int, byte[]) ignores position, so no flip() is needed
        for (int i = 0; i < CAPACITY; i++) {
            int offset = this.index[i];
            if (offset == 0) {
                continue;
            }
            int start = (int) mem[offset];
            int keyLen = (int) mem[offset + 1];

            byte[] keyByte = new byte[keyLen];
            data.get(start, keyByte);
            String key = new String(keyByte, StandardCharsets.UTF_8);
            target.compute(key, (k, v) -> {
                if (v == null) {
                    v = new MeasurementAggregator();
                }
                v.min = Math.min(v.min, mem[offset + 6]);
                v.max = Math.max(v.max, mem[offset + 7]);
                v.sum += mem[offset + 5];
                v.count += mem[offset + 4];
                return v;
            });
        }
    }
}
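The code that drives put is not shown. The following hypothetical sketch of such a caller scans a chunk byte by byte (the iteration-3 style, before the word-at-a-time tricks of iteration 4) and produces, for each line, the byteOffset, keyLength and hash that put takes, plus the value as a long. It assumes the value is the temperature in integer tenths and that every line, including the last, ends with '\n'; ChunkScanner, LineSink and scan are illustrative names.

```java
import java.nio.ByteBuffer;

public class ChunkScanner {

    // Receives the arguments SimpleHashMap.put needs for one line (hypothetical).
    interface LineSink {
        void accept(int byteOffset, int keyLength, int hash, long tenths);
    }

    // Walks "station;value\n" lines with absolute gets, so the buffer's position
    // is never modified. Assumes the chunk ends with '\n'.
    static void scan(ByteBuffer chunk, LineSink sink) {
        int pos = 0;
        int limit = chunk.limit();
        while (pos < limit) {
            int keyStart = pos;
            int hash = 0;
            byte b;
            while ((b = chunk.get(pos)) != ';') {
                hash = 31 * hash + b; // same formula as iteration 2
                pos++;
            }
            int keyLength = pos - keyStart;
            pos++; // skip ';'
            boolean negative = chunk.get(pos) == '-';
            if (negative) {
                pos++;
            }
            long value = 0;
            while ((b = chunk.get(pos)) != '\n') {
                if (b != '.') {
                    value = value * 10 + (b - '0');
                }
                pos++;
            }
            pos++; // skip '\n'
            sink.accept(keyStart, keyLength, hash, negative ? -value : value);
        }
    }
}
```

In a full implementation the sink would call put with these values and advance memOffset by 8 whenever a new station record is created.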

Results on an EC2 c5a.8xlarge instance (32 vCPUs, 64 GB memory):

[ec2-user@ip-172-31-33-122 1brc]$ time ./calculate_average_gumingcn.sh
real    0m2.068s
user    0m53.515s
sys     0m0.910s

// time results of other implementations, for comparison
[ec2-user@ip-172-31-33-122 1brc]$ time ./calculate_average_gonix.sh
real    0m1.023s
user    0m0.000s
sys     0m0.002s

[ec2-user@ip-172-31-33-122 1brc]$ time ./calculate_average_merykitty.sh 
real    0m1.502s
user    0m32.113s
sys     0m1.183s

[ec2-user@ip-172-31-33-122 1brc]$ time ./calculate_average_thomaswue.sh 
real    0m1.440s
user    0m34.203s
sys     0m0.912s

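Iteration 4

How to speed up MappedByteBuffer.get()

The author of the gonix solution above offers an approach.

In short: use bit tricks to find ';' and '\n', and use them to split each station;measurement\n line.

ByteBuffer.getLong() reads 8 bytes at a time, which cuts the cost of the byte-by-byte read loop. The trailing-zero tricks below assume the buffer's byte order is set to ByteOrder.LITTLE_ENDIAN, so that the first byte lands in the lowest 8 bits; a ByteBuffer is big-endian by default.

But how do we find ';'?

The ASCII code of ';' is 59 (hex 0x3B).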

private static long valueSepMark(long keyLong) {
    // Seen this trick used in multiple other solutions.
    // Nice breakdown here: https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
    long match = keyLong ^ 0x3B3B3B3B_3B3B3B3BL; // 3B == ';', so matching bytes become 0x00
    // Sets the high bit of each zero byte; the lowest set bit marks the first ';'
    match = (match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L);
    return match;
}
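To turn the mask into a byte index, take its lowest set bit. This only works if the long was loaded in little-endian order, so that byte i of the buffer sits in bits 8i to 8i+7; Long.numberOfTrailingZeros(match) >>> 3 is then the index of the first ';', and 8 means the word contains none. A sketch, assuming the buffer has been switched with order(ByteOrder.LITTLE_ENDIAN) and at least 8 bytes are readable at each step; SemicolonFinder, firstSemicolon and indexOfSemicolon are hypothetical names.

```java
import java.nio.ByteBuffer;

public class SemicolonFinder {

    // Same SWAR trick as valueSepMark: flags the high bit of every byte equal
    // to ';' (0x3B); only the lowest flagged byte is guaranteed to be exact.
    static long valueSepMark(long word) {
        long match = word ^ 0x3B3B3B3B_3B3B3B3BL;
        return (match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L);
    }

    // Index (0-7) of the first ';' in the 8 bytes at offset, or 8 if none.
    // Requires a LITTLE_ENDIAN buffer.
    static int firstSemicolon(ByteBuffer buf, int offset) {
        long word = buf.getLong(offset);
        return Long.numberOfTrailingZeros(valueSepMark(word)) >>> 3;
    }

    // Absolute index of the first ';' at or after offset, 8 bytes per step.
    // Assumes a ';' exists and 8 bytes are readable at every step.
    static int indexOfSemicolon(ByteBuffer buf, int offset) {
        while (true) {
            int i = firstSemicolon(buf, offset);
            if (i < 8) {
                return offset + i;
            }
            offset += 8;
        }
    }
}
```

Using the first set bit rather than any set bit matters: a borrow out of a real match can also flag a following ':' (0x3A, which becomes 0x01 after the XOR).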

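How do we quickly find '\n'?

Every 1BRC value has exactly one fractional digit, so it is enough to find the '.': the '\n' is always two bytes after it.

  • The ASCII codes of the digits '0'-'9' (0x30-0x39) have bit 4 (mask 0x10) set, e.g. '6' is 0011 0110
  • '.' (0x2E, 0010 1110) has bit 4 clear, and so do '\n' (0000 1010) and '-' (0010 1101)

So the '.' can be located with bit operations: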

private static int decimalSepMark(long value) {
    // Seen this trick used in multiple other solutions.
    // Looks like the original author is @merykitty.

    // The 4th binary digit of the ascii of a digit is 1 while
    // that of the '.' is 0. This finds the decimal separator
    // The value can be 12, 20, 28 ('.' at byte 1, 2 or 3; byte 0 is
    // skipped because it may be '-', which also has that bit clear)
    return Long.numberOfTrailingZeros(~value & 0x10101000);
}
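Because every 1BRC value has exactly one fractional digit, the '\n' is always two bytes after the '.', so decimalSepMark also gives the end of the line. A sketch, assuming `value` holds the 8 bytes that start at the first byte of the measurement, loaded little-endian; DecimalFinder, dotIndex, lineEndIndex and the load test helper are hypothetical names.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

public class DecimalFinder {

    // As in the text: 12, 20 or 28, the position of bit 4 of the '.' byte.
    static int decimalSepMark(long value) {
        return Long.numberOfTrailingZeros(~value & 0x10101000);
    }

    // Byte index of '.' relative to the start of the value (1, 2 or 3).
    static int dotIndex(long value) {
        return decimalSepMark(value) >>> 3;
    }

    // Byte index of '\n': exactly one fractional digit follows the '.'.
    static int lineEndIndex(long value) {
        return dotIndex(value) + 2;
    }

    // Test helper: loads up to 8 ASCII bytes little-endian, zero-padded.
    static long load(String s) {
        byte[] bytes = new byte[8];
        byte[] src = s.getBytes(StandardCharsets.US_ASCII);
        System.arraycopy(src, 0, bytes, 0, Math.min(8, src.length));
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getLong();
    }
}
```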

For more on these bit-manipulation tricks, see: read more

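The memory layout is optimized at the same time:

  • an int is enough for the key length (station names are at most 100 bytes)
  • an int is enough for an offset within a single chunk
  • so a single long can hold both
  • the measurement is computed with bit operations, as an integer in tenths of a degree rather than a double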
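private static long tailAndLen(int tailBits, long keyLong, long keyLen) {
    // Keep only the low tailBits bits of keyLong (tailBits must be < 64,
    // because Java masks the shift count: -1L << 64 == -1L)
    long tailMask = ~(-1L << tailBits);
    long tail = keyLong & tailMask;
    // Pack the tail above an 8-bit length field; the tail must fit in
    // 56 bits (at most 7 bytes) to survive the << 8
    return (tail << 8) | ((keyLen >> 3) & 0xFF);
}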
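tailAndLen packs the key's tail bytes together with a length. The plainer packing described in the bullets, a chunk offset and a key length in one long, could look like the following sketch; PackedKeyRef, pack, offset and length are hypothetical names, and both inputs are assumed to be non-negative ints.

```java
public class PackedKeyRef {

    // High 32 bits: byte offset of the key in the chunk; low 32 bits: key length.
    static long pack(int offset, int length) {
        return ((long) offset << 32) | (length & 0xFFFFFFFFL);
    }

    static int offset(long packed) {
        return (int) (packed >>> 32);
    }

    static int length(long packed) {
        return (int) packed; // keeps the low 32 bits
    }
}
```

Applied to the iteration-3 record, this would fold the byteOffset and keyLength slots into a single long.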

private static int decimalValue(int decimalSepMark, long value) {
    // Seen this trick used in multiple other solutions.
    // Looks like the original author is @merykitty.

    int shift = 28 - decimalSepMark;
    // signed is -1 if negative, 0 otherwise
    long signed = (~value << 59) >> 63;
    long designMask = ~(signed & 0xFF);
    // Align the number to a specific position and transform the ascii code
    // to actual digit value in each byte
    long digits = ((value & designMask) << shift) & 0x0F000F0F00L;

    // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit)
    // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
    // 0x000000UU00TTHH00 +
    // 0x00UU00TTHH000000 * 10 +
    // 0xUU00TTHH00000000 * 100
    // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400
    // This results in our value lies in the bit 32 to 41 of this product
    // That was close :)
    long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
    return (int) ((absValue ^ signed) - signed);
}
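One way to gain confidence in decimalValue is to compare it with a naive parser over every value 1BRC allows (-99.9 to 99.9). The sketch copies decimalSepMark and decimalValue from above with shorter comments, assumes the bytes starting at the value are loaded little-endian, and uses Double.parseDouble only as a slow reference; DecimalValueCheck, load, parseTenths and slowParse are hypothetical names.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

public class DecimalValueCheck {

    static int decimalSepMark(long value) {
        return Long.numberOfTrailingZeros(~value & 0x10101000);
    }

    static int decimalValue(int decimalSepMark, long value) {
        int shift = 28 - decimalSepMark;
        long signed = (~value << 59) >> 63;           // -1 if the first byte is '-'
        long designMask = ~(signed & 0xFF);           // clears the '-' byte
        long digits = ((value & designMask) << shift) & 0x0F000F0F00L;
        long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
        return (int) ((absValue ^ signed) - signed);  // conditional negation
    }

    // Value text (e.g. "-12.3") plus '\n', zero-padded to 8 bytes, little-endian.
    static long load(String text) {
        byte[] bytes = new byte[8];
        byte[] src = (text + "\n").getBytes(StandardCharsets.US_ASCII);
        System.arraycopy(src, 0, bytes, 0, src.length);
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getLong();
    }

    // Bit-trick parse into tenths of a degree.
    static int parseTenths(String text) {
        long word = load(text);
        return decimalValue(decimalSepMark(word), word);
    }

    // Slow reference parser, only for comparison.
    static int slowParse(String text) {
        return (int) Math.round(Double.parseDouble(text) * 10);
    }
}
```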
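  • In gonix's layout, a station whose key is at most 8 bytes needs only 4 slots, whereas my implementation always needs 8
  • The measurement is also parsed with bit operations, which is very efficient

The downside is that this much bit-manipulation code is hard to read.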

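Summary

The implementation from this post: https://github.com/guming/1brc

  • I did not test with Unsafe or MemorySegment/ByteVector (used under the hood by Flink)
  • I did not tune JVM options

The reason: in my view the decisive optimizations are the memory layout and using bit operations to simplify byte reading. Which utility classes you use is not the core issue, although they would certainly improve performance further.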