Some Thoughts on Enhancing the distinct Method of the Stream Interface

2021-04-12 14:35:26

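The Problem

Stream was introduced in Java 8, and its API has been refined continually ever since. Java 9 added ofNullable as well as two key methods, takeWhile and dropWhile. Sometimes we want to deduplicate the elements of a Stream. The default way to do so is the distinct method, for example: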

List<String> collect = Arrays.stream("test1,test2,test2,test3,test3".split(",")).distinct().collect(Collectors.toList());

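Under the hood this is hash-based deduplication that preserves encounter order, so it is nearly equivalent to collecting into a LinkedHashSet (which is itself backed by a LinkedHashMap):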

Set<String> collect = Arrays.stream("test1,test2,test2,test3,test3".split(",")).collect(Collectors.toCollection(LinkedHashSet::new));

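The results are the same. The hashCode() method locates the bucket, and the equals() method decides whether two elements are equal. If they are, the duplicate is dropped; if not, the element is kept. A LinkedHashMap preserves the original order.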
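
As a quick check of this equivalence, here is a minimal sketch that runs both pipelines on the same input. The class name DistinctEquivalence and its helper methods are made up for illustration and are not part of the original code.

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

// Hypothetical demo class, not part of the article's code.
final class DistinctEquivalence {
    static final String INPUT = "test1,test2,test2,test3,test3";

    // Deduplicate with Stream.distinct().
    static List<String> viaDistinct() {
        return Arrays.stream(INPUT.split(","))
                .distinct()
                .collect(Collectors.toList());
    }

    // Deduplicate by collecting into a LinkedHashSet.
    static Set<String> viaLinkedHashSet() {
        return Arrays.stream(INPUT.split(","))
                .collect(Collectors.toCollection(LinkedHashSet::new));
    }

    public static void main(String[] args) {
        System.out.println(viaDistinct());      // [test1, test2, test3]
        System.out.println(viaLinkedHashSet()); // [test1, test2, test3]
    }
}
```

Both variants rely on the elements' own hashCode() and equals(), which is exactly the limitation discussed next.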

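However, we do not always want to deduplicate the same class in the same way. Ideally distinct would, like sorted with its comparator, let us pass in functions that control when two objects count as equal and should be deduplicated.

Take the class below as an example: sometimes we want to deduplicate by id, and sometimes by name.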

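import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
@AllArgsConstructor // required by the new User(id, name) calls in the test below
public class User {
    private int id;
    private String name;
}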

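Thinking Through a Solution

Let's implement this distinct method first. To begin, we define a Key class that delegates hashCode and equals to functions supplied by the caller: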

private static final class Key<E> {
    // The wrapped element to compare
    private final E e;
    // Function that computes the element's hash code
    private final ToIntFunction<E> hashCode;
    // Function that decides whether two elements are equal
    private final BiPredicate<E, E> equals;

    public Key(E e, ToIntFunction<E> hashCode,
               BiPredicate<E, E> equals) {
        this.e = e;
        this.hashCode = hashCode;
        this.equals = equals;
    }

    @Override
    public int hashCode() {
        return hashCode.applyAsInt(e);
    }

    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof Key)) {
            return false;
        }
        @SuppressWarnings("unchecked")
        Key<E> that = (Key<E>) obj;
        return equals.test(this.e, that.e);
    }
}
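
To see what such a wrapper does inside a hash-based collection, here is a minimal sketch. KeyDemo and countDistinctByLength are hypothetical names, and comparing strings by length is just an illustrative choice of equality.

```java
import java.util.HashSet;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.function.ToIntFunction;

// Hypothetical demo class, not part of the article's code.
final class KeyDemo {
    // Same idea as the Key class above: hashing and equality come from the caller.
    static final class Key<E> {
        private final E e;
        private final ToIntFunction<E> hashCode;
        private final BiPredicate<E, E> equals;

        Key(E e, ToIntFunction<E> hashCode, BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    // Counts how many words remain when two words of equal length count as duplicates.
    static int countDistinctByLength(String... words) {
        ToIntFunction<String> hash = String::length;
        BiPredicate<String, String> eq = (a, b) -> a.length() == b.length();
        Set<Key<String>> seen = new HashSet<>();
        for (String w : words) {
            seen.add(new Key<>(w, hash, eq));
        }
        return seen.size();
    }

    public static void main(String[] args) {
        System.out.println(countDistinctByLength("a", "b", "cc", "dd", "eee")); // 3
    }
}
```

Note that the two functions must agree with each other: elements that the equality function treats as equal have to produce the same hash code, otherwise they may land in different buckets and never be compared.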

Then we add the new distinct method:

public Stream<T> distinct(
    ToIntFunction<T> hashCode,
    BiPredicate<T, T> equals,
    // When two elements are duplicates, which one is kept?
    BinaryOperator<T> merger
) {
    return this.collect(Collectors.toMap(
            t -> new Key<>(t, hashCode, equals),
            Function.identity(),
            merger,
            // A LinkedHashMap preserves the original order
            LinkedHashMap::new))
            .values()
            .stream();
}
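
The effect of the merger argument is easiest to see side by side. The sketch below is a standalone variant of the same toMap technique written as a static helper; MergerDemo, distinctBy and byFirstLetter are hypothetical names, and deduplicating words by their first letter is only an illustration.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical demo class, not part of the article's code.
final class MergerDemo {
    static final class Key<E> {
        private final E e;
        private final ToIntFunction<E> hashCode;
        private final BiPredicate<E, E> equals;

        Key(E e, ToIntFunction<E> hashCode, BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    // Deduplicates the stream; the merger picks the survivor among duplicates.
    static <T> List<T> distinctBy(Stream<T> stream,
                                  ToIntFunction<T> hashCode,
                                  BiPredicate<T, T> equals,
                                  BinaryOperator<T> merger) {
        Map<Key<T>, T> byKey = stream.collect(Collectors.toMap(
                (T t) -> new Key<T>(t, hashCode, equals),
                Function.identity(),
                merger,
                LinkedHashMap::new));
        return new ArrayList<>(byKey.values());
    }

    // Words that share a first letter are treated as duplicates.
    static List<String> byFirstLetter(BinaryOperator<String> merger) {
        return distinctBy(
                Stream.of("apple", "avocado", "banana", "blueberry", "cherry"),
                s -> s.charAt(0),
                (a, b) -> a.charAt(0) == b.charAt(0),
                merger);
    }

    public static void main(String[] args) {
        System.out.println(byFirstLetter((first, later) -> first)); // [apple, banana, cherry]
        System.out.println(byFirstLetter((first, later) -> later)); // [avocado, blueberry, cherry]
    }
}
```

In both cases the position in the result is that of the first occurrence, because a LinkedHashMap keeps the insertion order of its keys even when a value is later replaced by the merger.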

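So how do we get this method into Stream? The first idea that comes to mind is to wrap the Stream interface and delegate to it. The simplest implementation: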

public class EnhancedStream<T> implements Stream<T> {
    private final Stream<T> delegate;

    public EnhancedStream(Stream<T> delegate) {
        this.delegate = delegate;
    }

    private static final class Key<E> {
        // The wrapped element to compare
        private final E e;
        // Function that computes the element's hash code
        private final ToIntFunction<E> hashCode;
        // Function that decides whether two elements are equal
        private final BiPredicate<E, E> equals;

        public Key(E e, ToIntFunction<E> hashCode,
                   BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    public EnhancedStream<T> distinct(
            ToIntFunction<T> hashCode,
            BiPredicate<T, T> equals,
            // When two elements are duplicates, which one is kept?
            BinaryOperator<T> merger
    ) {
        return new EnhancedStream<>(
                delegate.collect(Collectors.toMap(
                        t -> new Key<>(t, hashCode, equals),
                        Function.identity(),
                        merger,
                        // A LinkedHashMap preserves the original order
                        LinkedHashMap::new))
                        .values()
                        .stream()
        );
    }

    @Override
    public EnhancedStream<T> filter(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.filter(predicate));
    }

    @Override
    public <R> EnhancedStream<R> map(Function<? super T, ? extends R> mapper) {
        return new EnhancedStream<>(delegate.map(mapper));
    }

    @Override
    public IntStream mapToInt(ToIntFunction<? super T> mapper) {
        return delegate.mapToInt(mapper);
    }

    @Override
    public LongStream mapToLong(ToLongFunction<? super T> mapper) {
        return delegate.mapToLong(mapper);
    }

    @Override
    public DoubleStream mapToDouble(ToDoubleFunction<? super T> mapper) {
        return delegate.mapToDouble(mapper);
    }

    @Override
    public <R> EnhancedStream<R> flatMap(Function<? super T, ? extends Stream<? extends R>> mapper) {
        return new EnhancedStream<>(delegate.flatMap(mapper));
    }

    @Override
    public IntStream flatMapToInt(Function<? super T, ? extends IntStream> mapper) {
        return delegate.flatMapToInt(mapper);
    }

    @Override
    public LongStream flatMapToLong(Function<? super T, ? extends LongStream> mapper) {
        return delegate.flatMapToLong(mapper);
    }

    @Override
    public DoubleStream flatMapToDouble(Function<? super T, ? extends DoubleStream> mapper) {
        return delegate.flatMapToDouble(mapper);
    }

    @Override
    public EnhancedStream<T> distinct() {
        return new EnhancedStream<>(delegate.distinct());
    }

    @Override
    public EnhancedStream<T> sorted() {
        return new EnhancedStream<>(delegate.sorted());
    }

    @Override
    public EnhancedStream<T> sorted(Comparator<? super T> comparator) {
        return new EnhancedStream<>(delegate.sorted(comparator));
    }

    @Override
    public EnhancedStream<T> peek(Consumer<? super T> action) {
        return new EnhancedStream<>(delegate.peek(action));
    }

    @Override
    public EnhancedStream<T> limit(long maxSize) {
        return new EnhancedStream<>(delegate.limit(maxSize));
    }

    @Override
    public EnhancedStream<T> skip(long n) {
        return new EnhancedStream<>(delegate.skip(n));
    }

    @Override
    public void forEach(Consumer<? super T> action) {
        delegate.forEach(action);
    }

    @Override
    public void forEachOrdered(Consumer<? super T> action) {
        delegate.forEachOrdered(action);
    }

    @Override
    public Object[] toArray() {
        return delegate.toArray();
    }

    @Override
    public <A> A[] toArray(IntFunction<A[]> generator) {
        return delegate.toArray(generator);
    }

    @Override
    public T reduce(T identity, BinaryOperator<T> accumulator) {
        return delegate.reduce(identity, accumulator);
    }

    @Override
    public Optional<T> reduce(BinaryOperator<T> accumulator) {
        return delegate.reduce(accumulator);
    }

    @Override
    public <U> U reduce(U identity, BiFunction<U, ? super T, U> accumulator, BinaryOperator<U> combiner) {
        return delegate.reduce(identity, accumulator, combiner);
    }

    @Override
    public <R> R collect(Supplier<R> supplier, BiConsumer<R, ? super T> accumulator, BiConsumer<R, R> combiner) {
        return delegate.collect(supplier, accumulator, combiner);
    }

    @Override
    public <R, A> R collect(Collector<? super T, A, R> collector) {
        return delegate.collect(collector);
    }

    @Override
    public Optional<T> min(Comparator<? super T> comparator) {
        return delegate.min(comparator);
    }

    @Override
    public Optional<T> max(Comparator<? super T> comparator) {
        return delegate.max(comparator);
    }

    @Override
    public long count() {
        return delegate.count();
    }

    @Override
    public boolean anyMatch(Predicate<? super T> predicate) {
        return delegate.anyMatch(predicate);
    }

    @Override
    public boolean allMatch(Predicate<? super T> predicate) {
        return delegate.allMatch(predicate);
    }

    @Override
    public boolean noneMatch(Predicate<? super T> predicate) {
        return delegate.noneMatch(predicate);
    }

    @Override
    public Optional<T> findFirst() {
        return delegate.findFirst();
    }

    @Override
    public Optional<T> findAny() {
        return delegate.findAny();
    }

    @Override
    public Iterator<T> iterator() {
        return delegate.iterator();
    }

    @Override
    public Spliterator<T> spliterator() {
        return delegate.spliterator();
    }

    @Override
    public boolean isParallel() {
        return delegate.isParallel();
    }

    @Override
    public EnhancedStream<T> sequential() {
        return new EnhancedStream<>(delegate.sequential());
    }

    @Override
    public EnhancedStream<T> parallel() {
        return new EnhancedStream<>(delegate.parallel());
    }

    @Override
    public EnhancedStream<T> unordered() {
        return new EnhancedStream<>(delegate.unordered());
    }

    @Override
    public EnhancedStream<T> onClose(Runnable closeHandler) {
        return new EnhancedStream<>(delegate.onClose(closeHandler));
    }

    @Override
    public void close() {
        delegate.close();
    }
}

Let's test it:

public static void main(String[] args) {
    List<User> users = List.of(
            new User(1, "test1"),
            new User(2, "test1"),
            new User(2, "test2"),
            new User(3, "test3"),
            new User(3, "test4"));
    // Deduplicate by id, keeping the first occurrence:
    // (1, test1), (2, test1), (3, test3)
    List<User> collect1 = new EnhancedStream<>(users.stream()).distinct(
            User::getId,
            (u1, u2) -> u1.getId() == u2.getId(),
            (u1, u2) -> u1
    ).collect(Collectors.toList());
    // Deduplicate by name, ignoring case and keeping the first occurrence:
    // (1, test1), (2, test2), (3, test3), (3, test4)
    List<User> collect2 = new EnhancedStream<>(users.stream()).distinct(
            // Hash the lower-cased name so the hash agrees with equalsIgnoreCase
            user -> user.getName().toLowerCase(Locale.ROOT).hashCode(),
            (u1, u2) -> u1.getName().equalsIgnoreCase(u2.getName()),
            (u1, u2) -> u1
    ).collect(Collectors.toList());
}

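Using a Dynamic Proxy

The implementation above contains a lot of redundant code, so we can consider a dynamic proxy instead. First, write the proxy interface: EnhancedStream extends the Stream interface, adds the new distinct method, and overrides every method that returns Stream so that it returns EnhancedStream. Only then does the returned object still expose the new distinct method.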

public interface EnhancedStream<T> extends Stream<T> {
    EnhancedStream<T> distinct(ToIntFunction<T> hashCode,
                               BiPredicate<T, T> equals,
                               BinaryOperator<T> merger);
    @Override
    EnhancedStream<T> filter(Predicate<? super T> predicate);
    @Override
    <R> EnhancedStream<R> map(
            Function<? super T, ? extends R> mapper);
    @Override
    <R> EnhancedStream<R> flatMap(
            Function<? super T, ? extends Stream<? extends R>> mapper);
    @Override
    EnhancedStream<T> distinct();
    @Override
    EnhancedStream<T> sorted();
    @Override
    EnhancedStream<T> sorted(Comparator<? super T> comparator);
    @Override
    EnhancedStream<T> peek(Consumer<? super T> action);
    @Override
    EnhancedStream<T> limit(long maxSize);
    @Override
    EnhancedStream<T> skip(long n);
    @Override
    EnhancedStream<T> takeWhile(Predicate<? super T> predicate);
    @Override
    EnhancedStream<T> dropWhile(Predicate<? super T> predicate);
    @Override
    EnhancedStream<T> sequential();
    @Override
    EnhancedStream<T> parallel();
    @Override
    EnhancedStream<T> unordered();
    @Override
    EnhancedStream<T> onClose(Runnable closeHandler);
}

Then write the handler class EnhancedStreamHandler, which dispatches the proxied method calls:

public class EnhancedStreamHandler<T> implements InvocationHandler {
    private Stream<T> delegate;

    public EnhancedStreamHandler(Stream<T> delegate) {
        this.delegate = delegate;
    }

    private static final Method ENHANCED_DISTINCT;
    static {
        try {
            ENHANCED_DISTINCT = EnhancedStream.class.getMethod(
                    "distinct", ToIntFunction.class, BiPredicate.class,
                    BinaryOperator.class
            );
        } catch (NoSuchMethodException e) {
            throw new Error(e);
        }
    }

    /**
     * Maps each EnhancedStream method to the corresponding Stream method
     */
    private static final Map<Method, Method> METHOD_MAP =
            Stream.of(EnhancedStream.class.getMethods())
                    .filter(m -> !m.equals(ENHANCED_DISTINCT))
                    .filter(m -> !Modifier.isStatic(m.getModifiers()))
                    .collect(Collectors.toUnmodifiableMap(
                            Function.identity(),
                            m -> {
                                try {
                                    return Stream.class.getMethod(
                                            m.getName(), m.getParameterTypes());
                                } catch (NoSuchMethodException e) {
                                    throw new Error(e);
                                }
                            }));


    @Override
    @SuppressWarnings("unchecked")
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        if (method.equals(ENHANCED_DISTINCT)) {
            // The call targets the extension method distinct
            return distinct(
                    (EnhancedStream<T>) proxy,
                    (ToIntFunction<T>) args[0],
                    (BiPredicate<T, T>) args[1],
                    (BinaryOperator<T>) args[2]);
        } else if (method.getReturnType() == EnhancedStream.class) {
            // A return type of EnhancedStream marks an overridden Stream method: forward it to the delegate
            Method match = METHOD_MAP.get(method);
            // Replace the delegate with the newly returned Stream
            this.delegate = (Stream<T>) match.invoke(this.delegate, args);
            return proxy;
        } else {
            // Otherwise, invoke the method directly on the delegate
            return method.invoke(this.delegate, args);
        }
    }

    private static final class Key<E> {
        private final E e;
        private final ToIntFunction<E> hashCode;
        private final BiPredicate<E, E> equals;

        public Key(E e, ToIntFunction<E> hashCode,
                   BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    private EnhancedStream<T> distinct(EnhancedStream<T> proxy,
                                       ToIntFunction<T> hashCode,
                                       BiPredicate<T, T> equals,
                                       BinaryOperator<T> merger) {
        delegate = delegate.collect(Collectors.toMap(
                t -> new Key<>(t, hashCode, equals),
                Function.identity(),
                merger,
                // Use a LinkedHashMap to keep the original order of the input
                LinkedHashMap::new))
                .values()
                .stream();
        return proxy;
    }
}
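
The core trick in this handler is to run an intermediate operation on the delegate, store the stream it returns, and hand the same proxy back to the caller. Here is a cut-down sketch of just that mechanism, under the assumption that only filter is overridden. FluentProxySketch, FluentStream, Handler and wrap are hypothetical names, and the target method is looked up on every call instead of being cached in a map.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical demo class, not part of the article's code.
final class FluentProxySketch {
    // Cut-down interface: only filter is overridden to return the sub-interface.
    interface FluentStream<T> extends Stream<T> {
        @Override
        FluentStream<T> filter(Predicate<? super T> predicate);
    }

    static final class Handler implements InvocationHandler {
        private Stream<?> delegate;

        Handler(Stream<?> delegate) {
            this.delegate = delegate;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            if (method.getReturnType() == FluentStream.class) {
                // Overridden intermediate operation: run the matching Stream method
                // on the delegate, keep the resulting stream, return the same proxy.
                Method target = Stream.class.getMethod(
                        method.getName(), method.getParameterTypes());
                delegate = (Stream<?>) target.invoke(delegate, args);
                return proxy;
            }
            // Everything else is forwarded to the delegate unchanged.
            return method.invoke(delegate, args);
        }
    }

    @SuppressWarnings("unchecked")
    static <T> FluentStream<T> wrap(Stream<T> stream) {
        return (FluentStream<T>) Proxy.newProxyInstance(
                FluentStream.class.getClassLoader(),
                new Class<?>[] {FluentStream.class},
                new Handler(stream));
    }

    static List<Integer> evenAndGreaterThanTwo() {
        return wrap(Stream.of(1, 2, 3, 4, 5, 6))
                .filter(x -> x % 2 == 0)
                .filter(x -> x > 2)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(evenAndGreaterThanTwo()); // [4, 6]
    }
}
```

One consequence of this design is that the proxy is stateful: every chained call mutates the handler's delegate, so earlier references to the proxy see the later pipeline stages as well.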

Finally, write a factory class that creates the EnhancedStream proxy:

public class EnhancedStreamFactory {
    @SuppressWarnings("unchecked")
    public static <T> EnhancedStream<T> newEnhancedStream(Stream<T> stream) {
        return (EnhancedStream<T>) Proxy.newProxyInstance(
                // Must be EnhancedStream's class loader, not Stream's: Stream is a JDK class loaded by the bootstrap class loader
                EnhancedStream.class.getClassLoader(),
                // The interface to proxy
                new Class<?>[] {EnhancedStream.class},
                // The invocation handler
                new EnhancedStreamHandler<>(stream)
        );
    }
}
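
The class loader comment can be checked directly. In the minimal sketch below, ClassLoaderCheck and MyStream are hypothetical names; MyStream stands in for an application-defined interface such as EnhancedStream.

```java
import java.util.stream.Stream;

// Hypothetical demo class, not part of the article's code.
final class ClassLoaderCheck {
    // Stand-in for an application-defined sub-interface of Stream.
    interface MyStream<T> extends Stream<T> {
    }

    // JDK core classes report null, which denotes the bootstrap class loader.
    static boolean streamLoadedByBootstrap() {
        return Stream.class.getClassLoader() == null;
    }

    // Application classes are loaded by a regular, non-null class loader.
    static boolean ownInterfaceHasLoader() {
        return MyStream.class.getClassLoader() != null;
    }

    public static void main(String[] args) {
        System.out.println(streamLoadedByBootstrap()); // true
        System.out.println(ownInterfaceHasLoader());   // true
    }
}
```

Because the bootstrap class loader cannot see application classes, the proxy has to be defined through the loader of the application interface.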

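This way the code looks more elegant, and even if the JDK adds more methods to Stream later, nothing here needs to change (although a newly added method that returns Stream will hand back a plain Stream until EnhancedStream overrides it).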
