如何限制 kotlin 协程的最大并发性

Posted

技术标签:

【中文标题】如何限制 kotlin 协程的最大并发性【英文标题】:how to cap kotlin coroutines maximum concurrency 【发布时间】:2018-05-21 01:06:48 【问题描述】:

我有一个序列(来自 File.walkTopDown),我需要在每个序列上运行一个长时间运行的操作。我想使用 Kotlin 最佳实践/协程,但我要么没有并行性,要么并行性太多,并遇到“打开文件太多”的 IO 错误。

File("/Users/me/Pictures/").walkTopDown()
    .onFail  file, ex -> println("ERROR: $file caused $ex") 
    .filter  ... only big images... 
    .map  file ->
        async  // I *think* I want async and not "launch"...
            ImageProcessor.fromFile(file)
        
    

这似乎不是并行运行的,而且我的多核 CPU 永远不会超过 1 个 CPU 的价值。有没有办法使用协程来运行“NumberOfCores 并行操作”的延迟作业?

我查看了Multithreading using Kotlin Coroutines,它首先创建了所有作业,然后加入它们,但这意味着在繁重的处理加入步骤之前完成序列/文件树遍历,这似乎......不确定!将其拆分为收集和处理步骤意味着收集可以在处理之前运行。

val jobs = ... the Sequence above...
    .toSet()
println("Found $jobs.size")
jobs.forEach  it.await() 

【问题讨论】:

【参考方案1】:

为什么不使用asFlow() 运算符,然后使用flatMapMerge

someCoroutineScope.launch(Dispatchers.Default) 
    File("/Users/me/Pictures/").walkTopDown()
        .asFlow()
        .filter  ... only big images... 
        .flatMapMerge(concurrencyLimit)  file ->
            flow 
                emit(runInterruptable  ImageProcessor.fromFile(file) )
            
        .catch  ... 
        .collect()
    

然后您可以限制同时打开的文件,同时仍同时处理它们。

【讨论】:

不错!我认为这行不通,因为我将controls the number of in-flight flows 读为“它可以合并多少个流”(在我的情况下,我只处理一个),但你现在让我认为它可能意味着“多少发出它可以立即咀嚼"【参考方案2】:

这不是针对您的问题,但它确实回答了“如何限制 kotlin 协程最大并发性”的问题。

编辑:从 kotlinx.coroutines 1.6.0 (https://github.com/Kotlin/kotlinx.coroutines/issues/2919) 开始,您可以使用 limitedParallelism,例如Dispatchers.IO.limitedParallelism(123).

旧解决方案:起初我想使用newFixedThreadPoolContext,但1)it's deprecated 和2)它会使用线程,我认为这不是必要的或可取的(与Executors.newFixedThreadPool().asCoroutineDispatcher() 相同)。这个解决方案可能存在我使用Semaphore 不知道的缺陷,但它非常简单:

import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit

/**
 * Maps the inputs using [transform] at most [maxConcurrency] at a time until all Jobs are done.
 */
suspend fun <TInput, TOutput> Iterable<TInput>.mapConcurrently(
    maxConcurrency: Int,
    transform: suspend (TInput) -> TOutput,
) = coroutineScope 
    val gate = Semaphore(maxConcurrency)
    this@mapConcurrently.map 
        async 
            gate.withPermit 
                transform(it)
            
        
    .awaitAll()


测试(抱歉,它使用 Spek、hamcrest 和 kotlin 测试):

import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.TestCoroutineDispatcher
import org.hamcrest.MatcherAssert.assertThat
import org.hamcrest.Matchers.greaterThanOrEqualTo
import org.hamcrest.Matchers.lessThanOrEqualTo
import org.spekframework.spek2.Spek
import org.spekframework.spek2.style.specification.describe
import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.assertEquals

@OptIn(ExperimentalCoroutinesApi::class)
object AsyncHelpersKtTest : Spek(
    val actionDelay: Long = 1_000 // arbitrary; obvious if non-test dispatcher is used on accident
    val testDispatcher = TestCoroutineDispatcher()

    afterEachTest 
        // Clean up the TestCoroutineDispatcher to make sure no other work is running.
        testDispatcher.cleanupTestCoroutines()
    

    describe("mapConcurrently") 
        it("should run all inputs concurrently if maxConcurrency >= size") 
            val concurrentJobCounter = AtomicInteger(0)
            val inputs = IntRange(1, 2).toList()
            val maxConcurrency = inputs.size

            // https://github.com/Kotlin/kotlinx.coroutines/issues/1266 has useful info & examples
            runBlocking(testDispatcher) 
                print("start runBlocking $coroutineContext\n")

                // We have to run this async so that the code afterwards can advance the virtual clock
                val job = launch 
                    testDispatcher.pauseDispatcher 
                        val result = inputs.mapConcurrently(maxConcurrency) 
                            print("action $it $coroutineContext\n")

                            // Sanity check that we never run more in parallel than max
                            assertThat(concurrentJobCounter.addAndGet(1), lessThanOrEqualTo(maxConcurrency))

                            // Allow for virtual clock adjustment
                            delay(actionDelay)

                            // Sanity check that we never run more in parallel than max
                            assertThat(concurrentJobCounter.getAndAdd(-1), lessThanOrEqualTo(maxConcurrency))
                            print("action $it after delay $coroutineContext\n")

                            it
                        

                        // Order is not guaranteed, thus a Set
                        assertEquals(inputs.toSet(), result.toSet())
                        print("end mapConcurrently $coroutineContext\n")
                    
                
                print("before advanceTime $coroutineContext\n")

                // Start the coroutines
                testDispatcher.advanceTimeBy(0)
                assertEquals(inputs.size, concurrentJobCounter.get(), "All jobs should have been started")

                testDispatcher.advanceTimeBy(actionDelay)
                print("after advanceTime $coroutineContext\n")
                assertEquals(0, concurrentJobCounter.get(), "All jobs should have finished")
                job.join()
            
        

        it("should run one at a time if maxConcurrency = 1") 
            val concurrentJobCounter = AtomicInteger(0)
            val inputs = IntRange(1, 2).toList()
            val maxConcurrency = 1

            runBlocking(testDispatcher) 
                val job = launch 
                    testDispatcher.pauseDispatcher 
                        inputs.mapConcurrently(maxConcurrency) 
                            assertThat(concurrentJobCounter.addAndGet(1), lessThanOrEqualTo(maxConcurrency))
                            delay(actionDelay)
                            assertThat(concurrentJobCounter.getAndAdd(-1), lessThanOrEqualTo(maxConcurrency))
                            it
                        
                    
                

                testDispatcher.advanceTimeBy(0)
                assertEquals(1, concurrentJobCounter.get(), "Only one job should have started")

                val elapsedTime = testDispatcher.advanceUntilIdle()
                print("elapsedTime=$elapsedTime")
                assertThat(
                    "Virtual time should be at least as long as if all jobs ran sequentially",
                    elapsedTime,
                    greaterThanOrEqualTo(actionDelay * inputs.size)
                )
                job.join()
            
        

        it("should handle cancellation") 
            val jobCounter = AtomicInteger(0)
            val inputs = IntRange(1, 2).toList()
            val maxConcurrency = 1

            runBlocking(testDispatcher) 
                val job = launch 
                    testDispatcher.pauseDispatcher 
                        inputs.mapConcurrently(maxConcurrency) 
                            jobCounter.addAndGet(1)
                            delay(actionDelay)
                            it
                        
                    
                

                testDispatcher.advanceTimeBy(0)
                assertEquals(1, jobCounter.get(), "Only one job should have started")

                job.cancel()
                testDispatcher.advanceUntilIdle()
                assertEquals(1, jobCounter.get(), "Only one job should have run")
                job.join()
            
        
    
)

根据https://play.kotlinlang.org/hands-on/Introduction%20to%20Coroutines%20and%20Channels/09_Testing,您可能还需要调整编译器参数以运行测试:

compileTestKotlin 
    kotlinOptions 
        // Needed for runBlocking test coroutine dispatcher?
        freeCompilerArgs += "-Xuse-experimental=kotlin.Experimental"
        freeCompilerArgs += "-Xopt-in=kotlin.RequiresOptIn"
    

testImplementation 'org.jetbrains.kotlinx:kotlinx-coroutines-test:1.4.1'

【讨论】:

拯救了我的一天!感谢分享,特别感谢测试。必须从那里学到很多东西【参考方案3】:

这不会保留投影的顺序,但会将吞吐量限制为最多 maxDegreeOfParallelism。按照您认为合适的方式扩展和扩展。

suspend fun <TInput, TOutput> (Collection<TInput>).inParallel(
        maxDegreeOfParallelism: Int,
        action: suspend CoroutineScope.(input: TInput) -> TOutput
): Iterable<TOutput> = coroutineScope 

    val list = this@inParallel

    if (list.isEmpty())
        return@coroutineScope listOf<TOutput>()

    val brake = Channel<Unit>(maxDegreeOfParallelism)
    val output = Channel<TOutput>()
    val counter = AtomicInteger(0)

    this.launch 

        repeat(maxDegreeOfParallelism) 
            brake.send(Unit)
        

        for (input in list) 

            val task = this.async 
                action(input)
            

            this.launch 
                val result = task.await()
                output.send(result)
                val completed = counter.incrementAndGet()
                if (completed == list.size) 
                    output.close()
                 else brake.send(Unit)
            

            brake.receive()
        
    

    val results = mutableListOf<TOutput>()
    for (item in output) 
        results.add(item)
    

    return@coroutineScope results

示例用法:

val output = listOf(1, 2, 3).inParallel(2) 
    it + 1
 // Note that output may not be in same order as list.

【讨论】:

【参考方案4】:

这会将协程限制为工人。我建议看https://www.youtube.com/watch?v=3WGM-_MnPQA

package com.example.workers

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.produce
import kotlin.system.measureTimeMillis

class ChannellibgradleApplication

fun main(args: Array<String>) 
    var myList = mutableListOf<Int>(3000,1200,1400,3000,1200,1400,3000)
    runBlocking 
        var myChannel = produce(CoroutineName("MyInts")) 
            myList.forEach  send(it) 
        

        println("Starting coroutineScope  ")
        var time = measureTimeMillis 
            coroutineScope 
                var workers = 2
                repeat(workers)
                
                    launch(CoroutineName("Sleep 1"))  theHardWork(myChannel) 
                
            
        
        println("Ending coroutineScope  $time ms")
    


suspend fun theHardWork(channel : ReceiveChannel<Int>) 

    for(m in channel) 
        println("Starting Sleep $m")
        delay(m.toLong())
        println("Ending Sleep $m")
    

【讨论】:

【参考方案5】:

我让它与一个频道一起工作。但也许我对你的方式有点多余?

val pipe = ArrayChannel<Deferred<ImageFile>>(20)
launch 
    while (!(pipe.isEmpty && pipe.isClosedForSend)) 
        imageFiles.add(pipe.receive().await())
    
    println("pipe closed")

File("/Users/me/").walkTopDown()
        .onFail  file, ex -> println("ERROR: $file caused $ex") 
        .forEach  pipe.send(async  ImageFile.fromFile(it) ) 
pipe.close()

【讨论】:

【参考方案6】:

您的第一个 sn-p 的问题在于它根本不运行 - 请记住,Sequence 是惰性的,您必须使用终端操作,例如 toSet()forEach()。此外,您需要通过构造newFixedThreadPoolContext 上下文并在async 中使用它来限制可用于该任务的线程数:

val pictureContext = newFixedThreadPoolContext(nThreads = 10, name = "reading pictures in parallel")

File("/Users/me/Pictures/").walkTopDown()
    .onFail  file, ex -> println("ERROR: $file caused $ex") 
    .filter  ... only big images... 
    .map  file ->
        async(pictureContext) 
            ImageProcessor.fromFile(file)
        
    
    .toList()
    .forEach  it.await() 

编辑: 您必须使用终端操作员 (toList) befor 等待结果

【讨论】:

我虽然可以,但它似乎仍然按顺序处理最终的 forEach。例如。 .map file -> async(CommonPool) println("start") val img = ImageFile.fromFile(file) println("end") img .forEach imageFiles.add(it.await()) if ( Math.random() > 0.999) imageFiles.save() 哦,snap,你是对的。现在我认为没有办法用序列来做到这一点。编辑了答案 值得注意的是,使用有限的线程池会限制并行性但不会限制并发性,这意味着如果ImageProcessor.fromFile 是一个挂起函数(不会阻塞),您仍然可以处理多个文件,这可能是不是你想要的。

以上是关于如何限制 kotlin 协程的最大并发性的主要内容,如果未能解决你的问题,请参考以下文章

并发编程之协程

Kotlin协程的原理,没有说得比AndroidDeveloper官方更显浅的了

Kotlin协程的原理,没有说得比AndroidDeveloper官方更显浅的了

发现不一样的Kotlin多方位处理协程的异常

发现不一样的Kotlin多方位处理协程的异常

Kotlin 协程源码解析