The problem
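Java 8 introduced Stream, and its API has been refined ever since. Java 9 added ofNullable as well as two key APIs, takeWhile and dropWhile. Sometimes we want to remove duplicate objects from a Stream. The default way to do that is the distinct API, for example: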
List<String> collect = Arrays.stream("test1,test2,test2,test3,test3".split(",")).distinct().collect(Collectors.toList());
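Under the hood it relies on a hash-based set (HashSet and LinkedHashSet are themselves backed by HashMap and LinkedHashMap), so it is almost equivalent to the following: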
Set<String> collect = Arrays.stream("test1,test2,test2,test3,test3".split(",")).collect(Collectors.toCollection(LinkedHashSet::new));
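The result is the same: hashCode() locates the bucket and equals() decides whether two elements are the same object. Equal elements are dropped as duplicates, all others are kept, and the LinkedHashSet (backed by a LinkedHashMap) preserves the original order.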
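To make the dependence on hashCode() and equals() concrete, here is a minimal sketch. The two classes PlainPoint and ValuePoint are hypothetical and exist only for this illustration: the first inherits identity-based equality from Object, the second overrides both methods consistently.

```java
import java.util.stream.Stream;

class DistinctContractDemo {

    // Hypothetical class that does NOT override equals()/hashCode().
    static final class PlainPoint {
        final int x;

        PlainPoint(int x) {
            this.x = x;
        }
    }

    // Hypothetical class that overrides both methods consistently.
    static final class ValuePoint {
        final int x;

        ValuePoint(int x) {
            this.x = x;
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof ValuePoint && ((ValuePoint) o).x == x;
        }

        @Override
        public int hashCode() {
            return Integer.hashCode(x);
        }
    }

    static long countDistinctPlain() {
        return Stream.of(new PlainPoint(1), new PlainPoint(1)).distinct().count();
    }

    static long countDistinctValue() {
        return Stream.of(new ValuePoint(1), new ValuePoint(1)).distinct().count();
    }

    public static void main(String[] args) {
        // Identity-based equality: both instances are kept.
        System.out.println(countDistinctPlain());
        // Value-based equality: the second instance is dropped as a duplicate.
        System.out.println(countDistinctValue());
    }
}
```

In other words, distinct() can only be as good as the equals()/hashCode() pair of the element type, which is exactly why a single fixed definition of equality becomes limiting.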
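However, we do not always want to deduplicate the same type of object in the same way. Ideally, just as sorted accepts a comparator, distinct would let us pass in the logic that decides when two objects count as equal and should be deduplicated.
Take the class below: sometimes we want to deduplicate by id, and sometimes by name.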
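import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
// Needed because the test further down calls new User(id, name)
@AllArgsConstructor
public class User {
    private int id;
    private String name;
}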
Working out a solution
Let's implement this distinct method first. To begin, we define a Key class that delegates hashCode and equals to caller-supplied functions:
private static final class Key<E> {
    // The object being compared
    private final E e;
    // Function that computes the object's hash code
    private final ToIntFunction<E> hashCode;
    // Predicate that decides whether two objects are equal
    private final BiPredicate<E, E> equals;

    public Key(E e, ToIntFunction<E> hashCode,
               BiPredicate<E, E> equals) {
        this.e = e;
        this.hashCode = hashCode;
        this.equals = equals;
    }

    @Override
    public int hashCode() {
        return hashCode.applyAsInt(e);
    }

    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof Key)) {
            return false;
        }
        @SuppressWarnings("unchecked")
        Key<E> that = (Key<E>) obj;
        return equals.test(this.e, that.e);
    }
}
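As a quick illustration of what such a wrapper buys us, the sketch below puts wrapped strings into a LinkedHashSet and treats two strings as equal when they have the same length. The equality rule and the helper name countByLength are made up for this example; the Key class is a compact copy of the one above so that the snippet compiles on its own.

```java
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.function.ToIntFunction;

class KeyWrapperDemo {

    // Compact copy of the Key idea: hashing and equality come from the caller.
    static final class Key<E> {
        private final E e;
        private final ToIntFunction<E> hashCode;
        private final BiPredicate<E, E> equals;

        Key(E e, ToIntFunction<E> hashCode, BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        @SuppressWarnings("unchecked")
        public boolean equals(Object obj) {
            return obj instanceof Key && equals.test(e, ((Key<E>) obj).e);
        }
    }

    // Hypothetical rule for illustration: strings of equal length count as duplicates.
    static int countByLength(String... values) {
        Set<Key<String>> seen = new LinkedHashSet<>();
        for (String v : values) {
            seen.add(new Key<String>(v, String::length, (a, b) -> a.length() == b.length()));
        }
        return seen.size();
    }

    public static void main(String[] args) {
        // "aa" and "bb" share a length, so only two keys remain.
        System.out.println(countByLength("aa", "bb", "ccc"));
    }
}
```

The set never calls String.equals here. It only sees Key objects, so the caller fully controls what "duplicate" means.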
Next, add the new distinct method:
public Stream<T> distinct(
        ToIntFunction<T> hashCode,
        BiPredicate<T, T> equals,
        // When duplicates are found, which one is kept?
        BinaryOperator<T> merger
) {
    return this.collect(Collectors.toMap(
            t -> new Key<>(t, hashCode, equals),
            Function.identity(),
            merger,
            // LinkedHashMap preserves the original order
            LinkedHashMap::new))
        .values()
        .stream();
}
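The role of the merger argument is easiest to see in isolation. The following sketch uses Collectors.toMap with a LinkedHashMap in the same way as the method above, but with a deliberately simple key (the first character of a word) instead of the Key wrapper. The helper name dedupeByFirstChar and the sample words are assumptions made for this example.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class MergerDemo {

    // Deduplicates words by their first character; the merger picks the survivor.
    static List<String> dedupeByFirstChar(Stream<String> words, BinaryOperator<String> merger) {
        return new ArrayList<>(words.collect(Collectors.toMap(
                w -> w.charAt(0),        // key used for duplicate detection
                Function.identity(),     // the value is the element itself
                merger,                  // called when two elements share a key
                LinkedHashMap::new))     // keeps keys in first-seen order
            .values());
    }

    public static void main(String[] args) {
        // Keep the first occurrence: [apple, bob]
        System.out.println(dedupeByFirstChar(Stream.of("apple", "bob", "avocado"), (a, b) -> a));
        // Keep the last occurrence: [avocado, bob]
        System.out.println(dedupeByFirstChar(Stream.of("apple", "bob", "avocado"), (a, b) -> b));
    }
}
```

Note that keeping the last occurrence changes which element survives but not its position: the LinkedHashMap remembers where the key was first inserted.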
How do we get this method into Stream? The first idea is to wrap the Stream interface in a delegating class. The simplest implementation:
public class EnhancedStream<T> implements Stream<T> {
    private Stream<T> delegate;

    public EnhancedStream(Stream<T> delegate) {
        this.delegate = delegate;
    }

    private static final class Key<E> {
        // The object being compared
        private final E e;
        // Function that computes the object's hash code
        private final ToIntFunction<E> hashCode;
        // Predicate that decides whether two objects are equal
        private final BiPredicate<E, E> equals;

        public Key(E e, ToIntFunction<E> hashCode,
                   BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    public EnhancedStream<T> distinct(
            ToIntFunction<T> hashCode,
            BiPredicate<T, T> equals,
            // When duplicates are found, which one is kept?
            BinaryOperator<T> merger
    ) {
        return new EnhancedStream<>(
            delegate.collect(Collectors.toMap(
                    t -> new Key<>(t, hashCode, equals),
                    Function.identity(),
                    merger,
                    // LinkedHashMap preserves the original order
                    LinkedHashMap::new))
                .values()
                .stream()
        );
    }

    @Override
    public EnhancedStream<T> filter(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.filter(predicate));
    }
    @Override
    public <R> EnhancedStream<R> map(Function<? super T, ? extends R> mapper) {
        return new EnhancedStream<>(delegate.map(mapper));
    }
    @Override
    public IntStream mapToInt(ToIntFunction<? super T> mapper) {
        return delegate.mapToInt(mapper);
    }
    @Override
    public LongStream mapToLong(ToLongFunction<? super T> mapper) {
        return delegate.mapToLong(mapper);
    }
    @Override
    public DoubleStream mapToDouble(ToDoubleFunction<? super T> mapper) {
        return delegate.mapToDouble(mapper);
    }
    @Override
    public <R> EnhancedStream<R> flatMap(Function<? super T, ? extends Stream<? extends R>> mapper) {
        return new EnhancedStream<>(delegate.flatMap(mapper));
    }
    @Override
    public IntStream flatMapToInt(Function<? super T, ? extends IntStream> mapper) {
        return delegate.flatMapToInt(mapper);
    }
    @Override
    public LongStream flatMapToLong(Function<? super T, ? extends LongStream> mapper) {
        return delegate.flatMapToLong(mapper);
    }
    @Override
    public DoubleStream flatMapToDouble(Function<? super T, ? extends DoubleStream> mapper) {
        return delegate.flatMapToDouble(mapper);
    }
    @Override
    public EnhancedStream<T> distinct() {
        return new EnhancedStream<>(delegate.distinct());
    }
    @Override
    public EnhancedStream<T> sorted() {
        return new EnhancedStream<>(delegate.sorted());
    }
    @Override
    public EnhancedStream<T> sorted(Comparator<? super T> comparator) {
        return new EnhancedStream<>(delegate.sorted(comparator));
    }
    @Override
    public EnhancedStream<T> peek(Consumer<? super T> action) {
        return new EnhancedStream<>(delegate.peek(action));
    }
    @Override
    public EnhancedStream<T> limit(long maxSize) {
        return new EnhancedStream<>(delegate.limit(maxSize));
    }
    @Override
    public EnhancedStream<T> skip(long n) {
        return new EnhancedStream<>(delegate.skip(n));
    }
    @Override
    public void forEach(Consumer<? super T> action) {
        delegate.forEach(action);
    }
    @Override
    public void forEachOrdered(Consumer<? super T> action) {
        delegate.forEachOrdered(action);
    }
    @Override
    public Object[] toArray() {
        return delegate.toArray();
    }
    @Override
    public <A> A[] toArray(IntFunction<A[]> generator) {
        return delegate.toArray(generator);
    }
    @Override
    public T reduce(T identity, BinaryOperator<T> accumulator) {
        return delegate.reduce(identity, accumulator);
    }
    @Override
    public Optional<T> reduce(BinaryOperator<T> accumulator) {
        return delegate.reduce(accumulator);
    }
    @Override
    public <U> U reduce(U identity, BiFunction<U, ? super T, U> accumulator, BinaryOperator<U> combiner) {
        return delegate.reduce(identity, accumulator, combiner);
    }
    @Override
    public <R> R collect(Supplier<R> supplier, BiConsumer<R, ? super T> accumulator, BiConsumer<R, R> combiner) {
        return delegate.collect(supplier, accumulator, combiner);
    }
    @Override
    public <R, A> R collect(Collector<? super T, A, R> collector) {
        return delegate.collect(collector);
    }
    @Override
    public Optional<T> min(Comparator<? super T> comparator) {
        return delegate.min(comparator);
    }
    @Override
    public Optional<T> max(Comparator<? super T> comparator) {
        return delegate.max(comparator);
    }
    @Override
    public long count() {
        return delegate.count();
    }
    @Override
    public boolean anyMatch(Predicate<? super T> predicate) {
        return delegate.anyMatch(predicate);
    }
    @Override
    public boolean allMatch(Predicate<? super T> predicate) {
        return delegate.allMatch(predicate);
    }
    @Override
    public boolean noneMatch(Predicate<? super T> predicate) {
        return delegate.noneMatch(predicate);
    }
    @Override
    public Optional<T> findFirst() {
        return delegate.findFirst();
    }
    @Override
    public Optional<T> findAny() {
        return delegate.findAny();
    }
    @Override
    public Iterator<T> iterator() {
        return delegate.iterator();
    }
    @Override
    public Spliterator<T> spliterator() {
        return delegate.spliterator();
    }
    @Override
    public boolean isParallel() {
        return delegate.isParallel();
    }
    @Override
    public EnhancedStream<T> sequential() {
        return new EnhancedStream<>(delegate.sequential());
    }
    @Override
    public EnhancedStream<T> parallel() {
        return new EnhancedStream<>(delegate.parallel());
    }
    @Override
    public EnhancedStream<T> unordered() {
        return new EnhancedStream<>(delegate.unordered());
    }
    @Override
    public EnhancedStream<T> onClose(Runnable closeHandler) {
        return new EnhancedStream<>(delegate.onClose(closeHandler));
    }
    @Override
    public void close() {
        delegate.close();
    }
}
Let's test it:
public static void main(String[] args) {
    List<User> users = new ArrayList<>() {{
        add(new User(1, "test1"));
        add(new User(2, "test1"));
        add(new User(2, "test2"));
        add(new User(3, "test3"));
        add(new User(3, "test4"));
    }};
    // Deduplicate by id, keeping the first occurrence:
    // (1, test1), (2, test1), (3, test3)
    List<User> collect1 = new EnhancedStream<>(users.stream()).distinct(
        User::getId,
        (u1, u2) -> u1.getId() == u2.getId(),
        (u1, u2) -> u1
    ).collect(Collectors.toList());
    // Deduplicate by name (ignoring case), keeping the first occurrence:
    // (1, test1), (2, test2), (3, test3), (3, test4)
    List<User> collect2 = new EnhancedStream<>(users.stream()).distinct(
        // Hash the lower-cased name so the hash agrees with equalsIgnoreCase
        user -> user.getName().toLowerCase().hashCode(),
        (u1, u2) -> u1.getName().equalsIgnoreCase(u2.getName()),
        (u1, u2) -> u1
    ).collect(Collectors.toList());
}
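One thing to watch when passing a hash function and an equality predicate separately: the two must agree, meaning that elements the predicate treats as equal must also produce the same hash. A hash map compares hashes before it ever calls equals, so a mismatch silently lets duplicates through. The sketch below shows the effect with a hypothetical Name wrapper whose equality ignores case.

```java
import java.util.HashSet;
import java.util.Set;
import java.util.function.ToIntFunction;

class HashConsistencyDemo {

    // Hypothetical wrapper: equality ignores case, hashing is supplied by the caller.
    static final class Name {
        private final String value;
        private final ToIntFunction<String> hash;

        Name(String value, ToIntFunction<String> hash) {
            this.value = value;
            this.hash = hash;
        }

        @Override
        public int hashCode() {
            return hash.applyAsInt(value);
        }

        @Override
        public boolean equals(Object obj) {
            return obj instanceof Name && value.equalsIgnoreCase(((Name) obj).value);
        }
    }

    // Number of entries a HashSet keeps for "Test1" and "test1" under the given hash.
    static int distinctCount(ToIntFunction<String> hash) {
        Set<Name> seen = new HashSet<>();
        seen.add(new Name("Test1", hash));
        seen.add(new Name("test1", hash));
        return seen.size();
    }

    public static void main(String[] args) {
        // Case-sensitive hash disagrees with case-insensitive equals: duplicate is missed.
        System.out.println(distinctCount(String::hashCode));
        // Hashing the lower-cased value agrees with equals: duplicate is detected.
        System.out.println(distinctCount(s -> s.toLowerCase().hashCode()));
    }
}
```

The same reasoning applies to any pair of functions handed to the custom distinct: derive the hash from exactly the data the predicate compares.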
Using a dynamic proxy
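The implementation above contains a lot of boilerplate, so consider using a dynamic proxy instead. First, write the proxy interface: EnhancedStream extends the Stream interface, adds the new distinct method, and overrides every method that returns Stream so that it returns EnhancedStream. Only then is the new distinct method still available on whatever those methods return.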
public interface EnhancedStream<T> extends Stream<T> {
    EnhancedStream<T> distinct(ToIntFunction<T> hashCode,
                               BiPredicate<T, T> equals,
                               BinaryOperator<T> merger);
    @Override
    EnhancedStream<T> filter(Predicate<? super T> predicate);
    @Override
    <R> EnhancedStream<R> map(
        Function<? super T, ? extends R> mapper);
    @Override
    <R> EnhancedStream<R> flatMap(
        Function<? super T, ? extends Stream<? extends R>> mapper);
    @Override
    EnhancedStream<T> distinct();
    @Override
    EnhancedStream<T> sorted();
    @Override
    EnhancedStream<T> sorted(Comparator<? super T> comparator);
    @Override
    EnhancedStream<T> peek(Consumer<? super T> action);
    @Override
    EnhancedStream<T> limit(long maxSize);
    @Override
    EnhancedStream<T> skip(long n);
    @Override
    EnhancedStream<T> takeWhile(Predicate<? super T> predicate);
    @Override
    EnhancedStream<T> dropWhile(Predicate<? super T> predicate);
    @Override
    EnhancedStream<T> sequential();
    @Override
    EnhancedStream<T> parallel();
    @Override
    EnhancedStream<T> unordered();
    @Override
    EnhancedStream<T> onClose(Runnable closeHandler);
}
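The interface above leans on covariant return types: an overriding method may declare a more specific return type than the method it overrides. The sketch below shows the same trick on a much smaller scale. Pipeline, EnhancedPipeline, and Impl are hypothetical names standing in for Stream, EnhancedStream, and an implementation.

```java
class CovariantReturnDemo {

    // Hypothetical base interface standing in for Stream.
    interface Pipeline {
        Pipeline step(String name);

        String describe();
    }

    // Narrows the return type of step() and adds one extra operation.
    interface EnhancedPipeline extends Pipeline {
        @Override
        EnhancedPipeline step(String name); // covariant return type

        EnhancedPipeline extra();
    }

    static final class Impl implements EnhancedPipeline {
        private final StringBuilder log = new StringBuilder();

        @Override
        public EnhancedPipeline step(String name) {
            log.append(name).append(';');
            return this;
        }

        @Override
        public EnhancedPipeline extra() {
            log.append("extra;");
            return this;
        }

        @Override
        public String describe() {
            return log.toString();
        }
    }

    static String run() {
        // extra() is callable after step() only because step() returns EnhancedPipeline.
        return new Impl().step("filter").extra().step("map").describe();
    }

    public static void main(String[] args) {
        System.out.println(run()); // filter;extra;map;
    }
}
```

Without the narrowed return type, the chain would fall back to the base interface after the first step() call and the extra operation would no longer compile.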
Then write the handler class EnhancedStreamHandler, which performs the method dispatch:
public class EnhancedStreamHandler<T> implements InvocationHandler {
    private Stream<T> delegate;

    public EnhancedStreamHandler(Stream<T> delegate) {
        this.delegate = delegate;
    }

    private static final Method ENHANCED_DISTINCT;
    static {
        try {
            ENHANCED_DISTINCT = EnhancedStream.class.getMethod(
                "distinct", ToIntFunction.class, BiPredicate.class,
                BinaryOperator.class
            );
        } catch (NoSuchMethodException e) {
            throw new Error(e);
        }
    }

    /**
     * Maps each EnhancedStream method to its counterpart on Stream.
     */
    private static final Map<Method, Method> METHOD_MAP =
        Stream.of(EnhancedStream.class.getMethods())
            .filter(m -> !m.equals(ENHANCED_DISTINCT))
            .filter(m -> !Modifier.isStatic(m.getModifiers()))
            .collect(Collectors.toUnmodifiableMap(
                Function.identity(),
                m -> {
                    try {
                        return Stream.class.getMethod(
                            m.getName(), m.getParameterTypes());
                    } catch (NoSuchMethodException e) {
                        throw new Error(e);
                    }
                }));

    @Override
    @SuppressWarnings("unchecked")
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        if (method.equals(ENHANCED_DISTINCT)) {
            // The call targets the added distinct method
            return distinct(
                (EnhancedStream<T>) proxy,
                (ToIntFunction<T>) args[0],
                (BiPredicate<T, T>) args[1],
                (BinaryOperator<T>) args[2]);
        } else if (method.getReturnType() == EnhancedStream.class) {
            // A return type of EnhancedStream marks an overridden method: route it through the mapping
            Method match = METHOD_MAP.get(method);
            // Update the delegate to the new Stream
            this.delegate = (Stream<T>) match.invoke(this.delegate, args);
            return proxy;
        } else {
            // Otherwise, invoke the method directly on the delegate
            return method.invoke(this.delegate, args);
        }
    }

    private static final class Key<E> {
        private final E e;
        private final ToIntFunction<E> hashCode;
        private final BiPredicate<E, E> equals;

        public Key(E e, ToIntFunction<E> hashCode,
                   BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    private EnhancedStream<T> distinct(EnhancedStream<T> proxy,
                                       ToIntFunction<T> hashCode,
                                       BiPredicate<T, T> equals,
                                       BinaryOperator<T> merger) {
        delegate = delegate.collect(Collectors.toMap(
                t -> new Key<>(t, hashCode, equals),
                Function.identity(),
                merger,
                // LinkedHashMap preserves the original order of the input
                LinkedHashMap::new))
            .values()
            .stream();
        return proxy;
    }
}
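For readers who have not used java.lang.reflect.Proxy before, here is a stripped-down sketch of the mechanism the handler builds on. The Greeter interface and the logging helper are hypothetical and only illustrate how every call on the proxy is funneled through InvocationHandler.invoke before reaching the real target.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

class ProxyDemo {

    // Hypothetical interface used only for this illustration.
    interface Greeter {
        String greet(String name);
    }

    // Wraps a Greeter in a proxy that records each method name before delegating.
    static Greeter logging(Greeter target, List<String> calls) {
        InvocationHandler handler = (proxy, method, args) -> {
            calls.add(method.getName());
            return method.invoke(target, args);
        };
        return (Greeter) Proxy.newProxyInstance(
                Greeter.class.getClassLoader(),
                new Class<?>[] {Greeter.class},
                handler);
    }

    static List<String> run() {
        List<String> calls = new ArrayList<>();
        Greeter greeter = logging(name -> "hello " + name, calls);
        calls.add(greeter.greet("stream"));
        return calls;
    }

    public static void main(String[] args) {
        System.out.println(run()); // [greet, hello stream]
    }
}
```

EnhancedStreamHandler follows the same pattern, except that it decides per method whether to run its own logic, swap the delegate, or pass the call straight through.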
Finally, write a factory class that creates the EnhancedStream proxy:
public class EnhancedStreamFactory {
    @SuppressWarnings("unchecked")
    public static <T> EnhancedStream<T> newEnhancedStream(Stream<T> stream) {
        return (EnhancedStream<T>) Proxy.newProxyInstance(
            // Must be EnhancedStream's class loader, not Stream's: Stream is a JDK class
            // loaded by the bootstrap class loader, which cannot see EnhancedStream
            EnhancedStream.class.getClassLoader(),
            // Interfaces to proxy
            new Class<?>[] {EnhancedStream.class},
            // Invocation handler
            new EnhancedStreamHandler<>(stream)
        );
    }
}
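The class loader remark in the factory can be checked directly. In the sketch below, MyStream is a hypothetical stand-in for EnhancedStream. For JDK classes such as Stream, getClassLoader() returns null, which denotes the bootstrap class loader, while an application-defined interface normally reports a real loader.

```java
import java.util.stream.Stream;

class ClassLoaderDemo {

    // Hypothetical application-level interface, standing in for EnhancedStream.
    interface MyStream<T> extends Stream<T> {
    }

    // Stream lives in java.base, so the bootstrap loader is reported as null.
    static boolean streamLoaderIsNull() {
        return Stream.class.getClassLoader() == null;
    }

    // Application classes are normally loaded by a non-null class loader.
    static boolean ownLoaderIsNull() {
        return MyStream.class.getClassLoader() == null;
    }

    public static void main(String[] args) {
        System.out.println(streamLoaderIsNull());
        System.out.println(ownLoaderIsNull());
    }
}
```

Passing the application interface's own loader to Proxy.newProxyInstance guarantees that the loader can resolve both the custom interface and the JDK types it extends.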
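This makes the code considerably more elegant. Even if the JDK adds more methods to Stream later, the handler itself does not need to change, although a new intermediate operation that returns Stream (such as mapMulti, added in Java 16) hands back a plain Stream unless EnhancedStream also overrides it.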