Some Thoughts on Enhancing the distinct Method of the Stream Interface

Posted by 干货满满张哈希


The Problem

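Java 8 introduced Stream, and its API has been refined ever since; Java 9 added ofNullable as well as the two key methods takeWhile and dropWhile. Sometimes we want to deduplicate the objects in a Stream. The default tool for that is distinct, for example: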

List<String> collect = Arrays.stream("test1,test2,test2,test3,test3".split(",")).distinct().collect(Collectors.toList());

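For an ordered stream, distinct() keeps the first occurrence of each element in encounter order, so it is nearly equivalent to collecting into a LinkedHashSet (which is itself backed by a LinkedHashMap):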

Set<String> collect = Arrays.stream("test1,test2,test2,test3,test3".split(",")).collect(Collectors.toCollection(LinkedHashSet::new));

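The result is the same: hashCode() locates the bucket and equals() decides whether two elements are the same object. Duplicates are dropped, everything else is kept, and the linked hash structure preserves the original order.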
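As a quick check of the equivalence described above, here is a minimal sketch that runs both variants on the same input. The class and method names (DistinctEquivalenceDemo, viaDistinct, viaLinkedHashSet) are made up for illustration.

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

// Hypothetical demo class: compares distinct() with collecting into a LinkedHashSet.
class DistinctEquivalenceDemo {

    // Deduplicate with the built-in distinct().
    static List<String> viaDistinct(String csv) {
        return Arrays.stream(csv.split(","))
                .distinct()
                .collect(Collectors.toList());
    }

    // Deduplicate by collecting into a LinkedHashSet, which keeps insertion order.
    static Set<String> viaLinkedHashSet(String csv) {
        return Arrays.stream(csv.split(","))
                .collect(Collectors.toCollection(LinkedHashSet::new));
    }

    public static void main(String[] args) {
        String csv = "test1,test2,test2,test3,test3";
        System.out.println(viaDistinct(csv));      // [test1, test2, test3]
        System.out.println(viaLinkedHashSet(csv)); // [test1, test2, test3]
    }
}
```

Both variants rely on String.hashCode() and String.equals(), which is exactly the part we would like to make pluggable.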

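However, we do not always want to deduplicate the same type of object in the same way. Ideally distinct would, like sorted, accept a caller-supplied strategy that decides when two objects count as equal and should be deduplicated.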

Take the following class: sometimes we want to deduplicate by id, and sometimes by name.

@Data
@NoArgsConstructor
@AllArgsConstructor // needed because the test below calls new User(id, name)
public class User {
    private int id;
    private String name;
}
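To see why the built-in distinct() is not enough here, consider this small sketch. It assumes a hand-written stand-in for the Lombok class (Lombok's @Data generates equals and hashCode over all fields); the names DefaultDistinctDemo and countDistinct are hypothetical.

```java
import java.util.List;
import java.util.Objects;

// Hypothetical demo: the built-in distinct() can only use the class's own equals/hashCode.
class DefaultDistinctDemo {

    // Stand-in for the Lombok-generated User: equality covers BOTH id and name.
    static final class User {
        final int id;
        final String name;

        User(int id, String name) {
            this.id = id;
            this.name = name;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof User)) return false;
            User that = (User) o;
            return id == that.id && Objects.equals(name, that.name);
        }

        @Override
        public int hashCode() {
            return Objects.hash(id, name);
        }
    }

    // Counts the elements that survive the default distinct().
    static long countDistinct(List<User> users) {
        return users.stream().distinct().count();
    }

    public static void main(String[] args) {
        List<User> users = List.of(
                new User(1, "test1"),
                new User(2, "test1"),
                new User(2, "test2"),
                new User(3, "test3"),
                new User(3, "test4"));
        // No two users agree on both fields, so nothing is removed.
        System.out.println(countDistinct(users)); // 5
    }
}
```

Neither "one user per id" nor "one user per name" can be expressed without changing equals and hashCode on the class itself.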

Working Toward a Solution

Let's implement this distinct method. First, define a Key class that delegates hashCode and equals to caller-supplied functions:

private static final class Key<E> {
    // the object being compared
    private final E e;
    // computes the hash code of the object
    private final ToIntFunction<E> hashCode;
    // decides whether two objects are equal
    private final BiPredicate<E, E> equals;

    public Key(E e, ToIntFunction<E> hashCode,
               BiPredicate<E, E> equals) {
        this.e = e;
        this.hashCode = hashCode;
        this.equals = equals;
    }

    @Override
    public int hashCode() {
        return hashCode.applyAsInt(e);
    }

    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof Key)) {
            return false;
        }
        @SuppressWarnings("unchecked")
        Key<E> that = (Key<E>) obj;
        return equals.test(this.e, that.e);
    }
}

Next, add the new distinct method:

public Stream<T> distinct(
    ToIntFunction<T> hashCode,
    BiPredicate<T, T> equals,
    // when two elements collide, which one is kept?
    BinaryOperator<T> merger
) {
    return this.collect(Collectors.toMap(
            t -> new Key<>(t, hashCode, equals),
            Function.identity(),
            merger,
            // LinkedHashMap preserves the original order
            LinkedHashMap::new))
        .values()
        .stream();
}
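The method above is written as if it lived inside Stream. To try the idea on its own, the following sketch moves it into a static helper; distinctBy, byFirstLetter and DistinctByDemo are hypothetical names, and the Key class is a compact copy of the one shown earlier.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical standalone version of the custom distinct, for experimentation.
class DistinctByDemo {

    // Wrapper that delegates hashCode/equals to caller-supplied functions.
    private static final class Key<E> {
        private final E e;
        private final ToIntFunction<E> hashCode;
        private final BiPredicate<E, E> equals;

        Key(E e, ToIntFunction<E> hashCode, BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    // Deduplicates the stream; merger picks the survivor when two elements are "equal".
    static <T> Stream<T> distinctBy(Stream<T> source,
                                    ToIntFunction<T> hashCode,
                                    BiPredicate<T, T> equals,
                                    BinaryOperator<T> merger) {
        Map<Key<T>, T> survivors = source.collect(Collectors.toMap(
                t -> new Key<>(t, hashCode, equals),
                Function.identity(),
                merger,
                LinkedHashMap::new)); // keeps the position of the first occurrence
        return survivors.values().stream();
    }

    // Example use: treat words as equal when they start with the same letter.
    static List<String> byFirstLetter(List<String> words, BinaryOperator<String> merger) {
        return distinctBy(words.stream(),
                w -> (int) w.charAt(0),
                (a, b) -> a.charAt(0) == b.charAt(0),
                merger)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> words = List.of("apple", "avocado", "banana", "blueberry", "cherry");
        System.out.println(byFirstLetter(words, (a, b) -> a)); // [apple, banana, cherry]
        System.out.println(byFirstLetter(words, (a, b) -> b)); // [avocado, blueberry, cherry]
    }
}
```

Two properties of this approach are worth keeping in mind. The collect call is a terminal operation, so the source stream is consumed eagerly at this point rather than lazily, and Collectors.toMap rejects null values, so a stream containing null elements would fail with a NullPointerException.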

How do we get this method into Stream? The first idea is to wrap the Stream interface in a delegating class. The simplest implementation:

import java.util.*;
import java.util.function.*;
import java.util.stream.*;

public class EnhancedStream<T> implements Stream<T> {
    private final Stream<T> delegate;

    public EnhancedStream(Stream<T> delegate) {
        this.delegate = delegate;
    }

    private static final class Key<E> {
        // the object being compared
        private final E e;
        // computes the hash code of the object
        private final ToIntFunction<E> hashCode;
        // decides whether two objects are equal
        private final BiPredicate<E, E> equals;

        public Key(E e, ToIntFunction<E> hashCode,
                   BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    public EnhancedStream<T> distinct(
            ToIntFunction<T> hashCode,
            BiPredicate<T, T> equals,
            // when two elements collide, which one is kept?
            BinaryOperator<T> merger
    ) {
        return new EnhancedStream<>(
                delegate.collect(Collectors.toMap(
                        t -> new Key<>(t, hashCode, equals),
                        Function.identity(),
                        merger,
                        // LinkedHashMap preserves the original order
                        LinkedHashMap::new))
                        .values()
                        .stream()
        );
    }

    @Override
    public EnhancedStream<T> filter(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.filter(predicate));
    }

    @Override
    public <R> EnhancedStream<R> map(Function<? super T, ? extends R> mapper) {
        return new EnhancedStream<>(delegate.map(mapper));
    }

    @Override
    public IntStream mapToInt(ToIntFunction<? super T> mapper) {
        return delegate.mapToInt(mapper);
    }

    @Override
    public LongStream mapToLong(ToLongFunction<? super T> mapper) {
        return delegate.mapToLong(mapper);
    }

    @Override
    public DoubleStream mapToDouble(ToDoubleFunction<? super T> mapper) {
        return delegate.mapToDouble(mapper);
    }

    @Override
    public <R> EnhancedStream<R> flatMap(Function<? super T, ? extends Stream<? extends R>> mapper) {
        return new EnhancedStream<>(delegate.flatMap(mapper));
    }

    @Override
    public IntStream flatMapToInt(Function<? super T, ? extends IntStream> mapper) {
        return delegate.flatMapToInt(mapper);
    }

    @Override
    public LongStream flatMapToLong(Function<? super T, ? extends LongStream> mapper) {
        return delegate.flatMapToLong(mapper);
    }

    @Override
    public DoubleStream flatMapToDouble(Function<? super T, ? extends DoubleStream> mapper) {
        return delegate.flatMapToDouble(mapper);
    }

    @Override
    public EnhancedStream<T> distinct() {
        return new EnhancedStream<>(delegate.distinct());
    }

    @Override
    public EnhancedStream<T> sorted() {
        return new EnhancedStream<>(delegate.sorted());
    }

    @Override
    public EnhancedStream<T> sorted(Comparator<? super T> comparator) {
        return new EnhancedStream<>(delegate.sorted(comparator));
    }

    @Override
    public EnhancedStream<T> peek(Consumer<? super T> action) {
        return new EnhancedStream<>(delegate.peek(action));
    }

    @Override
    public EnhancedStream<T> limit(long maxSize) {
        return new EnhancedStream<>(delegate.limit(maxSize));
    }

    @Override
    public EnhancedStream<T> skip(long n) {
        return new EnhancedStream<>(delegate.skip(n));
    }

    @Override
    public void forEach(Consumer<? super T> action) {
        delegate.forEach(action);
    }

    @Override
    public void forEachOrdered(Consumer<? super T> action) {
        delegate.forEachOrdered(action);
    }

    @Override
    public Object[] toArray() {
        return delegate.toArray();
    }

    @Override
    public <A> A[] toArray(IntFunction<A[]> generator) {
        return delegate.toArray(generator);
    }

    @Override
    public T reduce(T identity, BinaryOperator<T> accumulator) {
        return delegate.reduce(identity, accumulator);
    }

    @Override
    public Optional<T> reduce(BinaryOperator<T> accumulator) {
        return delegate.reduce(accumulator);
    }

    @Override
    public <U> U reduce(U identity, BiFunction<U, ? super T, U> accumulator, BinaryOperator<U> combiner) {
        return delegate.reduce(identity, accumulator, combiner);
    }

    @Override
    public <R> R collect(Supplier<R> supplier, BiConsumer<R, ? super T> accumulator, BiConsumer<R, R> combiner) {
        return delegate.collect(supplier, accumulator, combiner);
    }

    @Override
    public <R, A> R collect(Collector<? super T, A, R> collector) {
        return delegate.collect(collector);
    }

    @Override
    public Optional<T> min(Comparator<? super T> comparator) {
        return delegate.min(comparator);
    }

    @Override
    public Optional<T> max(Comparator<? super T> comparator) {
        return delegate.max(comparator);
    }

    @Override
    public long count() {
        return delegate.count();
    }

    @Override
    public boolean anyMatch(Predicate<? super T> predicate) {
        return delegate.anyMatch(predicate);
    }

    @Override
    public boolean allMatch(Predicate<? super T> predicate) {
        return delegate.allMatch(predicate);
    }

    @Override
    public boolean noneMatch(Predicate<? super T> predicate) {
        return delegate.noneMatch(predicate);
    }

    @Override
    public Optional<T> findFirst() {
        return delegate.findFirst();
    }

    @Override
    public Optional<T> findAny() {
        return delegate.findAny();
    }

    @Override
    public Iterator<T> iterator() {
        return delegate.iterator();
    }

    @Override
    public Spliterator<T> spliterator() {
        return delegate.spliterator();
    }

    @Override
    public boolean isParallel() {
        return delegate.isParallel();
    }

    @Override
    public EnhancedStream<T> sequential() {
        return new EnhancedStream<>(delegate.sequential());
    }

    @Override
    public EnhancedStream<T> parallel() {
        return new EnhancedStream<>(delegate.parallel());
    }

    @Override
    public EnhancedStream<T> unordered() {
        return new EnhancedStream<>(delegate.unordered());
    }

    @Override
    public EnhancedStream<T> onClose(Runnable closeHandler) {
        return new EnhancedStream<>(delegate.onClose(closeHandler));
    }

    @Override
    public void close() {
        delegate.close();
    }
}

Let's test it:

public static void main(String[] args) {
    List<User> users = new ArrayList<>() {{
        add(new User(1, "test1"));
        add(new User(2, "test1"));
        add(new User(2, "test2"));
        add(new User(3, "test3"));
        add(new User(3, "test4"));
    }};
    // deduplicate by id, keeping the first user seen for each id
    List<User> collect1 = new EnhancedStream<>(users.stream()).distinct(
            User::getId,
            (u1, u2) -> u1.getId() == u2.getId(),
            (u1, u2) -> u1
    ).collect(Collectors.toList());
    // deduplicate by name, ignoring case; the hash must ignore case as well,
    // otherwise names that differ only in case never reach the equals check
    List<User> collect2 = new EnhancedStream<>(users.stream()).distinct(
            user -> user.getName().toLowerCase().hashCode(),
            (u1, u2) -> u1.getName().equalsIgnoreCase(u2.getName()),
            (u1, u2) -> u1
    ).collect(Collectors.toList());
    System.out.println(collect1);
    System.out.println(collect2);
}
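When the equality predicate ignores case, the hash function has to ignore case too, because a hash map only calls equals on keys whose hashes match. The sketch below, with the hypothetical names HashEqualsConsistencyDemo and dedupeIgnoringCase, shows what happens when the two disagree.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;

// Hypothetical demo: a hash function that disagrees with the equality predicate.
class HashEqualsConsistencyDemo {

    // Compact copy of the Key wrapper used in this article.
    private static final class Key<E> {
        private final E e;
        private final ToIntFunction<E> hashCode;
        private final BiPredicate<E, E> equals;

        Key(E e, ToIntFunction<E> hashCode, BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    // Deduplicates names with equalsIgnoreCase and the given hash function, keeping the first.
    static List<String> dedupeIgnoringCase(List<String> names, ToIntFunction<String> hash) {
        Map<Key<String>, String> survivors = names.stream().collect(Collectors.toMap(
                n -> new Key<>(n, hash, String::equalsIgnoreCase),
                Function.identity(),
                (first, second) -> first,
                LinkedHashMap::new));
        return new ArrayList<>(survivors.values());
    }

    public static void main(String[] args) {
        List<String> names = List.of("Test1", "test1", "test2");

        // Case-sensitive hash: "Test1" and "test1" hash differently, so equals is never consulted.
        System.out.println(dedupeIgnoringCase(names, String::hashCode));
        // [Test1, test1, test2]

        // Case-insensitive hash: both land on the same key and are merged.
        System.out.println(dedupeIgnoringCase(names, n -> n.toLowerCase(Locale.ROOT).hashCode()));
        // [Test1, test2]
    }
}
```

The same rule applies to any pair of functions passed to the custom distinct: elements that the predicate considers equal must produce the same hash value.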

Using a Dynamic Proxy

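The implementation above contains a lot of redundant code, so consider a dynamic proxy instead. First write the proxied interface: EnhancedStream extends Stream, adds the new distinct method, and overrides every method that returns a Stream so that it returns an EnhancedStream. Only then does each returned object still expose the new distinct method.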

import java.util.Comparator;
import java.util.function.*;
import java.util.stream.Stream;

public interface EnhancedStream<T> extends Stream<T> {
    EnhancedStream<T> distinct(ToIntFunction<T> hashCode,
                               BiPredicate<T, T> equals,
                               BinaryOperator<T> merger);
    @Override
    EnhancedStream<T> filter(Predicate<? super T> predicate);
    @Override
    <R> EnhancedStream<R> map(
            Function<? super T, ? extends R> mapper);
    @Override
    <R> EnhancedStream<R> flatMap(
            Function<? super T, ? extends Stream<? extends R>> mapper);
    @Override
    EnhancedStream<T> distinct();
    @Override
    EnhancedStream<T> sorted();
    @Override
    EnhancedStream<T> sorted(Comparator<? super T> comparator);
    @Override
    EnhancedStream<T> peek(Consumer<? super T> action);
    @Override
    EnhancedStream<T> limit(long maxSize);
    @Override
    EnhancedStream<T> skip(long n);
    @Override
    EnhancedStream<T> takeWhile(Predicate<? super T> predicate);
    @Override
    EnhancedStream<T> dropWhile(Predicate<? super T> predicate);
    @Override
    EnhancedStream<T> sequential();
    @Override
    EnhancedStream<T> parallel();
    @Override
    EnhancedStream<T> unordered();
    @Override
    EnhancedStream<T> onClose(Runnable closeHandler);
}

Next, write the EnhancedStreamHandler class, which implements the method dispatch:

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class EnhancedStreamHandler<T> implements InvocationHandler {
    // replaced on every intermediate operation, so the proxy is stateful
    private Stream<T> delegate;

    public EnhancedStreamHandler(Stream<T> delegate) {
        this.delegate = delegate;
    }

    private static final Method ENHANCED_DISTINCT;
    static {
        try {
            ENHANCED_DISTINCT = EnhancedStream.class.getMethod(
                    "distinct", ToIntFunction.class, BiPredicate.class,
                    BinaryOperator.class
            );
        } catch (NoSuchMethodException e) {
            throw new Error(e);
        }
    }

    /**
     * Maps each EnhancedStream method to the Stream method with the same signature.
     */
    private static final Map<Method, Method> METHOD_MAP =
            Stream.of(EnhancedStream.class.getMethods())
                    .filter(m -> !m.equals(ENHANCED_DISTINCT))
                    .filter(m -> !Modifier.isStatic(m.getModifiers()))
                    .collect(Collectors.toUnmodifiableMap(
                            Function.identity(),
                            m -> {
                                try {
                                    return Stream.class.getMethod(
                                            m.getName(), m.getParameterTypes());
                                } catch (NoSuchMethodException e) {
                                    throw new Error(e);
                                }
                            }));

    @Override
    @SuppressWarnings("unchecked")
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        if (method.equals(ENHANCED_DISTINCT)) {
            // the call is the extension method distinct
            return distinct(
                    (EnhancedStream<T>) proxy,
                    (ToIntFunction<T>) args[0],
                    (BiPredicate<T, T>) args[1],
                    (BinaryOperator<T>) args[2]);
        } else if (method.getReturnType() == EnhancedStream.class) {
            // a return type of EnhancedStream means this is one of the overridden methods
            Method match = METHOD_MAP.get(method);
            // update the delegate to the new Stream and keep handing out the same proxy
            this.delegate = (Stream<T>) match.invoke(this.delegate, args);
            return proxy;
        } else {
            // otherwise call straight through to the delegate
            return method.invoke(this.delegate, args);
        }
    }

    private static final class Key<E> {
        private final E e;
        private final ToIntFunction<E> hashCode;
        private final BiPredicate<E, E> equals;

        public Key(E e, ToIntFunction<E> hashCode,
                   BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    private EnhancedStream<T> distinct(EnhancedStream<T> proxy,
                                       ToIntFunction<T> hashCode,
                                       BiPredicate<T, T> equals,
                                       BinaryOperator<T> merger) {
        delegate = delegate.collect(Collectors.toMap(
                t -> new Key<>(t, hashCode, equals),
                Function.identity(),
                merger,
                // LinkedHashMap keeps the original order of the input
                LinkedHashMap::new))
                .values()
                .stream();
        return proxy;
    }
}
Finally, write a factory class that creates the EnhancedStream proxy:

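import java.lang.reflect.Proxy;
import java.util.stream.Stream;

public class EnhancedStreamFactory {
    @SuppressWarnings("unchecked")
    public static <E> EnhancedStream<E> newEnhancedStream(Stream<E> stream) {
        return (EnhancedStream<E>) Proxy.newProxyInstance(
                // must be EnhancedStream's class loader, not Stream's: Stream is a JDK class
                // loaded by the bootstrap class loader, which cannot see EnhancedStream
                EnhancedStream.class.getClassLoader(),
                // the proxied interface
                new Class<?>[] {EnhancedStream.class},
                // the invocation handler
                new EnhancedStreamHandler<>(stream)
        );
    }
}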

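This makes the code look more elegant, and even if the JDK adds more methods to Stream later, nothing here has to change. (A method that EnhancedStream does not override simply falls through to the delegate and returns a plain Stream.)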
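For readers who have not used java.lang.reflect.Proxy before, here is a reduced sketch of the same pattern: an interface that extends Stream, one method answered by the handler, and everything else forwarded to the real stream. LabeledStream, wrap and ProxyLoaderDemo are hypothetical names used only for this illustration.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical, reduced version of the proxy pattern used in this article.
class ProxyLoaderDemo {

    // A Stream plus one extra method, in the spirit of EnhancedStream.
    interface LabeledStream<T> extends Stream<T> {
        String label();
    }

    @SuppressWarnings("unchecked")
    static <T> LabeledStream<T> wrap(Stream<T> delegate, String label) {
        InvocationHandler handler = (proxy, method, args) -> {
            if (method.getName().equals("label") && method.getParameterCount() == 0) {
                // the extra method is answered by the handler itself
                return label;
            }
            // every other call is forwarded to the real stream
            return method.invoke(delegate, args);
        };
        return (LabeledStream<T>) Proxy.newProxyInstance(
                // the loader that defined our own interface, not Stream's loader
                LabeledStream.class.getClassLoader(),
                new Class<?>[] {LabeledStream.class},
                handler);
    }

    public static void main(String[] args) {
        // Stream is loaded by the bootstrap class loader, which is reported as null.
        System.out.println(Stream.class.getClassLoader()); // null

        LabeledStream<String> stream = wrap(Stream.of("a", "b", "a"), "demo");
        System.out.println(stream.label()); // demo
        // distinct() is not overridden here, so it returns the delegate's plain Stream.
        System.out.println(stream.distinct().collect(Collectors.toList())); // [a, b]
    }
}
```

The last call shows the fall-through behaviour: any method the custom interface does not redeclare still works, but its result is an ordinary Stream without the extra method.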
