增强 Stream 接口的 distinct 方法的一些思考
Posted 干货满满张哈希
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了增强 Stream 接口的 distinct 方法的一些思考相关的知识,希望对你有一定的参考价值。
遇到的问题
Java 8 开始引入了 Stream,其中的 API 一直在不断优化、更新和完善,Java 9 中更是引入了 `ofNullable`、`takeWhile` 和 `dropWhile` 这几个关键 API。有时候,我们想对 Stream 中的对象进行排重,默认可以用 `distinct` 这个 API,例如:
// Stream.distinct keeps the first occurrence of each element in encounter order: [test1, test2, test3]
String[] items = "test1,test2,test2,test3,test3".split(",");
List<String> collect = Arrays.stream(items)
        .distinct()
        .collect(Collectors.toList());
以 OpenJDK 的实现为例,顺序流下它用一个 `HashSet` 记录已经出现过的元素(有序的并行流则借助 `LinkedHashSet` 汇总),效果上和下面的实现几乎是等价的:
// LinkedHashSet drops duplicates via hashCode()/equals() and remembers insertion order.
String[] items = "test1,test2,test2,test3,test3".split(",");
Set<String> collect = Arrays.stream(items)
        .collect(Collectors.toCollection(LinkedHashSet::new));
两者结果是一样的:都靠 `hashCode()` 方法定位哈希槽,再用 `equals()` 方法判断是否是同一个对象,如果是则作为重复元素被去掉,不是的话保留;第二种写法通过 `LinkedHashSet`(内部基于 `LinkedHashMap`)来保留原始顺序。
但是,对于同一类对象,有时候我们排重的方式并不统一,所以最好能像 `sorted` 接口一样,允许我们传入比较器,来控制如何判断两个对象相等、需要排重。例如下面的这个对象,我们有时候想按照 `id` 排重,有时候想按照 `name` 进行排重。
/**
 * Sample entity used by the examples; users can be de-duplicated either by {@code id} or by
 * {@code name}.
 *
 * <p>Written out by hand instead of Lombok's {@code @Data}/{@code @NoArgsConstructor}: the
 * examples construct users with {@code new User(id, name)}, and those two annotations do not
 * generate that all-arguments constructor.
 */
public class User {
    private int id;
    private String name;

    /** Creates a user with {@code id == 0} and a {@code null} name. */
    public User() {
    }

    /**
     * Creates a fully initialised user.
     *
     * @param id   the user id
     * @param name the user name
     */
    public User(int id, String name) {
        this.id = id;
        this.name = name;
    }

    public int getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    /** Two users are equal when both {@code id} and {@code name} are equal. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof User)) {
            return false;
        }
        User other = (User) o;
        return id == other.id && java.util.Objects.equals(name, other.name);
    }

    @Override
    public int hashCode() {
        return java.util.Objects.hash(id, name);
    }

    @Override
    public String toString() {
        return "User(id=" + id + ", name=" + name + ")";
    }
}
解决思考
首先来实现这个 `distinct` 方法。我们先定义一个 `Key` 类,用来代理 `hashCode` 和 `equals` 方法:
private static final class Key<E>
//要比较的对象
private final E e;
//获取对象的hashcode的方法
private final ToIntFunction<E> hashCode;
//判断两个对象是否相等的方法
private final BiPredicate<E, E> equals;
public Key(E e, ToIntFunction<E> hashCode,
BiPredicate<E, E> equals)
this.e = e;
this.hashCode = hashCode;
this.equals = equals;
@Override
public int hashCode()
return hashCode.applyAsInt(e);
@Override
public boolean equals(Object obj)
if (!(obj instanceof Key))
return false;
@SuppressWarnings("unchecked")
Key<E> that = (Key<E>) obj;
return equals.test(this.e, that.e);
然后,增加新的 `distinct` 方法:
public Stream<T> distinct (
ToIntFunction<T> hashCode,
BiPredicate<T, T> equals,
//排重的时候,保留哪一个?
BinaryOperator<T> merger
)
return this.collect(Collectors.toMap(
t -> new Key<>(t, hashCode, equals),
Function.identity(),
merger,
//通过LinkedHashMap来保持原有的顺序
LinkedHashMap::new))
.values()
.stream();
然后,这个方法如何放入 Stream 呢?我们首先想到的就是代理 `Stream` 接口,最简单的实现如下:
/**
 * A {@link Stream} decorator that adds
 * {@link #distinct(ToIntFunction, BiPredicate, BinaryOperator)}, a de-duplication operation
 * driven by caller-supplied hash and equality functions.
 *
 * <p>Every intermediate operation that returns a {@code Stream<T>}/{@code Stream<R>} is
 * overridden to return an {@code EnhancedStream}, so the extra {@code distinct} stays
 * available anywhere in a pipeline. All other operations simply forward to the wrapped
 * stream.
 *
 * @param <T> the element type
 */
public class EnhancedStream<T> implements Stream<T> {

    /** The wrapped stream that every operation forwards to. */
    private final Stream<T> delegate;

    /**
     * Wraps {@code delegate}.
     *
     * @param delegate the stream to decorate
     */
    public EnhancedStream(Stream<T> delegate) {
        this.delegate = delegate;
    }

    /**
     * Adapts caller-supplied hash and equality functions to {@link Object#hashCode()} and
     * {@link Object#equals(Object)} so an element can be used as a hash-map key.
     */
    private static final class Key<E> {
        // the element being compared
        private final E e;
        // computes the element's hash code
        private final ToIntFunction<E> hashCode;
        // decides whether two elements are equal
        private final BiPredicate<E, E> equals;

        public Key(E e, ToIntFunction<E> hashCode,
                   BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Key)) {
                return false;
            }
            // Keys are only compared within one map built by one distinct() call.
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    /**
     * Removes duplicates using the supplied hash and equality functions instead of the
     * elements' own {@code hashCode()}/{@code equals()}.
     *
     * <p>This operation is eager: the wrapped stream is fully consumed when the method is
     * called, and a new stream over the surviving elements is returned, in the encounter order
     * of the first element of each group. The returned stream keeps the wrapped stream's
     * parallel mode, and closing it also runs the wrapped stream's close handlers.
     *
     * @param hashCode computes the hash code used for grouping; must agree with {@code equals}
     * @param equals   decides whether two elements are duplicates
     * @param merger   picks which element to keep; receives (earlier kept element, later duplicate)
     * @return a stream with one element per group of duplicates
     * @throws NullPointerException if the stream contains a {@code null} element
     *         ({@code Collectors.toMap} rejects null values)
     */
    public EnhancedStream<T> distinct(
            ToIntFunction<T> hashCode,
            BiPredicate<T, T> equals,
            // which of two duplicates to keep
            BinaryOperator<T> merger
    ) {
        Stream<T> deduplicated = delegate.collect(Collectors.toMap(
                        t -> new Key<>(t, hashCode, equals),
                        Function.identity(),
                        merger,
                        // LinkedHashMap keeps the original encounter order
                        LinkedHashMap::new))
                .values()
                .stream();
        // values().stream() is always sequential and knows nothing about the wrapped stream's
        // close handlers, so carry both over explicitly.
        if (delegate.isParallel()) {
            deduplicated = deduplicated.parallel();
        }
        return new EnhancedStream<>(deduplicated.onClose(delegate::close));
    }

    // ---- Intermediate operations: forward and re-wrap so distinct(...) stays reachable ----

    @Override
    public EnhancedStream<T> filter(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.filter(predicate));
    }

    @Override
    public <R> EnhancedStream<R> map(Function<? super T, ? extends R> mapper) {
        return new EnhancedStream<>(delegate.map(mapper));
    }

    @Override
    public IntStream mapToInt(ToIntFunction<? super T> mapper) {
        return delegate.mapToInt(mapper);
    }

    @Override
    public LongStream mapToLong(ToLongFunction<? super T> mapper) {
        return delegate.mapToLong(mapper);
    }

    @Override
    public DoubleStream mapToDouble(ToDoubleFunction<? super T> mapper) {
        return delegate.mapToDouble(mapper);
    }

    @Override
    public <R> EnhancedStream<R> flatMap(Function<? super T, ? extends Stream<? extends R>> mapper) {
        return new EnhancedStream<>(delegate.flatMap(mapper));
    }

    @Override
    public IntStream flatMapToInt(Function<? super T, ? extends IntStream> mapper) {
        return delegate.flatMapToInt(mapper);
    }

    @Override
    public LongStream flatMapToLong(Function<? super T, ? extends LongStream> mapper) {
        return delegate.flatMapToLong(mapper);
    }

    @Override
    public DoubleStream flatMapToDouble(Function<? super T, ? extends DoubleStream> mapper) {
        return delegate.flatMapToDouble(mapper);
    }

    @Override
    public EnhancedStream<T> distinct() {
        return new EnhancedStream<>(delegate.distinct());
    }

    @Override
    public EnhancedStream<T> sorted() {
        return new EnhancedStream<>(delegate.sorted());
    }

    @Override
    public EnhancedStream<T> sorted(Comparator<? super T> comparator) {
        return new EnhancedStream<>(delegate.sorted(comparator));
    }

    @Override
    public EnhancedStream<T> peek(Consumer<? super T> action) {
        return new EnhancedStream<>(delegate.peek(action));
    }

    @Override
    public EnhancedStream<T> limit(long maxSize) {
        return new EnhancedStream<>(delegate.limit(maxSize));
    }

    @Override
    public EnhancedStream<T> skip(long n) {
        return new EnhancedStream<>(delegate.skip(n));
    }

    // Stream's default takeWhile/dropWhile return a plain Stream; override them so the
    // enhanced distinct(...) is still reachable afterwards.
    @Override
    public EnhancedStream<T> takeWhile(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.takeWhile(predicate));
    }

    @Override
    public EnhancedStream<T> dropWhile(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.dropWhile(predicate));
    }

    // ---- Terminal operations: forward unchanged ----

    @Override
    public void forEach(Consumer<? super T> action) {
        delegate.forEach(action);
    }

    @Override
    public void forEachOrdered(Consumer<? super T> action) {
        delegate.forEachOrdered(action);
    }

    @Override
    public Object[] toArray() {
        return delegate.toArray();
    }

    @Override
    public <A> A[] toArray(IntFunction<A[]> generator) {
        return delegate.toArray(generator);
    }

    @Override
    public T reduce(T identity, BinaryOperator<T> accumulator) {
        return delegate.reduce(identity, accumulator);
    }

    @Override
    public Optional<T> reduce(BinaryOperator<T> accumulator) {
        return delegate.reduce(accumulator);
    }

    @Override
    public <U> U reduce(U identity, BiFunction<U, ? super T, U> accumulator, BinaryOperator<U> combiner) {
        return delegate.reduce(identity, accumulator, combiner);
    }

    @Override
    public <R> R collect(Supplier<R> supplier, BiConsumer<R, ? super T> accumulator, BiConsumer<R, R> combiner) {
        return delegate.collect(supplier, accumulator, combiner);
    }

    @Override
    public <R, A> R collect(Collector<? super T, A, R> collector) {
        return delegate.collect(collector);
    }

    @Override
    public Optional<T> min(Comparator<? super T> comparator) {
        return delegate.min(comparator);
    }

    @Override
    public Optional<T> max(Comparator<? super T> comparator) {
        return delegate.max(comparator);
    }

    @Override
    public long count() {
        return delegate.count();
    }

    @Override
    public boolean anyMatch(Predicate<? super T> predicate) {
        return delegate.anyMatch(predicate);
    }

    @Override
    public boolean allMatch(Predicate<? super T> predicate) {
        return delegate.allMatch(predicate);
    }

    @Override
    public boolean noneMatch(Predicate<? super T> predicate) {
        return delegate.noneMatch(predicate);
    }

    @Override
    public Optional<T> findFirst() {
        return delegate.findFirst();
    }

    @Override
    public Optional<T> findAny() {
        return delegate.findAny();
    }

    // ---- BaseStream operations ----

    @Override
    public Iterator<T> iterator() {
        return delegate.iterator();
    }

    @Override
    public Spliterator<T> spliterator() {
        return delegate.spliterator();
    }

    @Override
    public boolean isParallel() {
        return delegate.isParallel();
    }

    @Override
    public EnhancedStream<T> sequential() {
        return new EnhancedStream<>(delegate.sequential());
    }

    @Override
    public EnhancedStream<T> parallel() {
        return new EnhancedStream<>(delegate.parallel());
    }

    @Override
    public EnhancedStream<T> unordered() {
        return new EnhancedStream<>(delegate.unordered());
    }

    @Override
    public EnhancedStream<T> onClose(Runnable closeHandler) {
        return new EnhancedStream<>(delegate.onClose(closeHandler));
    }

    @Override
    public void close() {
        delegate.close();
    }
}
测试下:
public static void main(String[] args)
List<User> users = new ArrayList<>()
add(new User(1, "test1"));
add(new User(2, "test1"));
add(new User(2, "test2"));
add(new User(3, "test3"));
add(new User(3, "test4"));
;
List<User> collect1 = new EnhancedStream<>(users.stream()).distinct(
User::getId,
(u1, u2) -> u1.getId() == u2.getId(),
(u1, u2) -> u1
).collect(Collectors.toList());
List<User> collect2 = new EnhancedStream<>(users.stream()).distinct(
user -> user.getName().hashCode(),
(u1, u2) -> u1.getName().equalsIgnoreCase(u2.getName()),
(u1, u2) -> u1
).collect(Collectors.toList());
通过动态代理
上面这种实现有很多冗余代码,可以考虑使用动态代理实现。首先编写代理接口类:让 `EnhancedStream` 继承 `Stream` 接口,增加 `distinct` 接口,并让所有返回 `Stream` 的接口改为返回 `EnhancedStream`,这样在返回值上才能继续使用新的 `distinct` 接口。
/**
 * A {@link Stream} that additionally offers
 * {@link #distinct(ToIntFunction, BiPredicate, BinaryOperator)}.
 *
 * <p>Every {@code Stream} method returning a stream of objects is redeclared here with an
 * {@code EnhancedStream} return type, so the extra {@code distinct} remains reachable after
 * any intermediate operation. Instances are intended to be created as dynamic proxies
 * (see {@code EnhancedStreamFactory}).
 *
 * @param <T> the element type
 */
public interface EnhancedStream<T> extends Stream<T> {

    /**
     * Removes duplicates using the supplied hash and equality functions instead of the
     * elements' own {@code hashCode()}/{@code equals()}.
     *
     * @param hashCode computes the hash code used for grouping; must agree with {@code equals}
     * @param equals   decides whether two elements are duplicates
     * @param merger   picks which of two duplicates to keep
     * @return a stream with one element per group of duplicates
     */
    EnhancedStream<T> distinct(ToIntFunction<T> hashCode,
                               BiPredicate<T, T> equals,
                               BinaryOperator<T> merger);

    @Override
    EnhancedStream<T> filter(Predicate<? super T> predicate);

    @Override
    <R> EnhancedStream<R> map(
            Function<? super T, ? extends R> mapper);

    @Override
    <R> EnhancedStream<R> flatMap(
            Function<? super T, ? extends Stream<? extends R>> mapper);

    @Override
    EnhancedStream<T> distinct();

    @Override
    EnhancedStream<T> sorted();

    @Override
    EnhancedStream<T> sorted(Comparator<? super T> comparator);

    @Override
    EnhancedStream<T> peek(Consumer<? super T> action);

    @Override
    EnhancedStream<T> limit(long maxSize);

    @Override
    EnhancedStream<T> skip(long n);

    @Override
    EnhancedStream<T> takeWhile(Predicate<? super T> predicate);

    @Override
    EnhancedStream<T> dropWhile(Predicate<? super T> predicate);

    @Override
    EnhancedStream<T> sequential();

    @Override
    EnhancedStream<T> parallel();

    @Override
    EnhancedStream<T> unordered();

    @Override
    EnhancedStream<T> onClose(Runnable closeHandler);
}
然后,编写代理类 `EnhancedStreamHandler` 实现方法代理:
public class EnhancedStreamHandler<T> implements InvocationHandler
private Stream<T> delegate;
public EnhancedStreamHandler(Stream<T> delegate)
this.delegate = delegate;
private static final Method ENHANCED_DISTINCT;
static
try
ENHANCED_DISTINCT = EnhancedStream.class.getMethod(
"distinct", ToIntFunction.class, BiPredicate.class,
BinaryOperator.class
);
catch (NoSuchMethodException e)
throw new Error(e);
/**
* 将EnhancedStream的方法与Stream的方法一一对应
*/
private static final Map<Method, Method> METHOD_MAP =
Stream.of(EnhancedStream.class.getMethods())
.filter(m -> !m.equals(ENHANCED_DISTINCT))
.filter(m -> !Modifier.isStatic(m.getModifiers()))
.collect(Collectors.toUnmodifiableMap(
Function.identity(),
m ->
try
return Stream.class.getMethod(
m.getName(), m.getParameterTypes());
catch (NoSuchMethodException e)
throw new Error(e);
));
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable
if (method.equals(ENHANCED_DISTINCT))
//调用方法为扩展方法distinct
return distinct(
(EnhancedStream<T>) proxy,
(ToIntFunction<T>) args[0],
(BiPredicate<T, T>) args[1],
(BinaryOperator<T>) args[2]);
else if (method.getReturnType() == EnhancedStream.class)
//对于返回类型为EnhancedStream的,证明是代理的方法调用,走代理
Method match = METHOD_MAP.get(method);
//更相信代理对象为新的Stream
this.delegate = (Stream) match.invoke(this.delegate, args);
return proxy;
else
//否则,直接用代理类调用
return method.invoke(this.delegate, args);
private static final class Key<E>
private final E e;
private final ToIntFunction<E> hashCode;
private final BiPredicate<E, E> equals;
public Key(E e, ToIntFunction<E> hashCode,
BiPredicate<E, E> equals)
this.e = e;
this.hashCode = hashCode;
this.equals = equals;
@Override
public int hashCode()
return hashCode.applyAsInt(e);
@Override
public boolean equals(Object obj)
if (!(obj instanceof Key))
return false;
@SuppressWarnings("unchecked")
Key<E> that = (Key<E>) obj;
return equals.test(this.e, that.e);
private EnhancedStream<T> distinct(EnhancedStream<T> proxy,
ToIntFunction<T> hashCode,
BiPredicate<T, T> equals,
BinaryOperator<T> merger)
delegate = delegate.collect(Collectors.toMap(
t -> new Key<>(t, hashCode, equals),
Function.identity(),
merger,
//使用LinkedHashMap,保持入参原始顺序
LinkedHashMap::new))
.values()
.stream();
return proxy;
最后编写工厂类,生成 `EnhancedStream` 代理对象:
public class EnhancedStreamFactory
public static <E> EnhancedStream<E> newEnhancedStream(Stream<E> stream)
return (EnhancedStream<E>) Proxy.newProxyInstance(
//必须用EnhancedStream的classLoader,不能用Stream的,因为Stream是jdk的类,ClassLoader是rootClassLoader
EnhancedStream.class.getClassLoader(),
//代理接口
new Class<?>[] EnhancedStream.class,
//代理类
new EnhancedStreamHandler<>(stream)
);
这样,代码看上去更优雅了;即使 JDK 以后扩展更多方法,这里也可以不用修改(不过,新增的返回 `Stream` 的方法需要在 `EnhancedStream` 接口中重新声明返回类型,调用后才能继续使用增强的 `distinct`)。
以上是关于增强 Stream 接口的 distinct 方法的一些思考的主要内容,如果未能解决你的问题,请参考以下文章
精细篇Java8强大的stream API接口大全(代码优雅之道)