Java 中每个键的线程池

Thread Pool per key in Java

假设有一个包含 n × m 个单元格的网格 G,其中 n·m 很大。此外,假设有许多任务,每个任务属于 G 中的某个单元格,这些任务应当并行执行(在线程池或其他资源池中)。

但是,属于同一个单元格的任务必须串行完成,也就是说,每个任务都必须等待同一单元格中的前一个任务完成后才能开始。

我该如何解决这个问题?我已经搜索并尝试了几种线程池方案(Executors、Thread),但都没有成功。

最小工作示例

import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Minimal example from the question: submits 10,000 tasks, each tagged with a random grid
 * cell (nx, ny), to a fixed pool of 16 threads. As written, tasks that share a cell can still
 * run concurrently; making them run serially per cell is what the question asks for.
 */
public class MWE {

    public static void main(String[] args) {
        ExecutorService threadPool = Executors.newFixedThreadPool(16);
        Random r = new Random();

        try {
            for (int i = 0; i < 10000; i++) {
                int nx = r.nextInt(10);
                int ny = r.nextInt(10);

                Runnable task = () -> {
                    try {
                        System.out.println("Task is running");
                        Thread.sleep(1000);
                    } catch (InterruptedException e) {
                        // Restore the interrupt flag so the pool can observe the interruption.
                        Thread.currentThread().interrupt();
                        e.printStackTrace();
                    }
                };

                // Submit the Runnable itself: wrapping it in a never-started Thread adds nothing.
                threadPool.submit(task); // Should use nx,ny here somehow
            }
        } finally {
            // Without shutdown() the pool's non-daemon worker threads keep the JVM alive forever.
            // Already-submitted tasks still run to completion.
            threadPool.shutdown();
        }
    }

}

如果我没看错,你想在 Y 个队列(Y 比 X 小得多)中执行 X 个任务(X 很大)。
Java 8 提供了 CompletableFuture 类,用于表示一个(异步)计算。基本上,它是 Java 对 Promise 的实现。下面演示如何组织计算链(省略了泛型类型参数):

// Seed the queue with an already-completed stage so the first appended task can start at once.
CompletableFuture queue = CompletableFuture.completedFuture(null);
// Each thenRunAsync(...) returns a NEW stage that runs only after the previous stage completes
// (and is skipped if that stage failed). Keep the newest stage in `queue` so that later tasks
// can be appended behind it.
queue = queue
        .thenRunAsync(() -> System.out.println("first task running"))
        .thenRunAsync(() -> System.out.println("second task running"));
// ... and so on

当您使用 thenRunAsync(Runnable) 时,任务会交给默认的异步执行设施(通常是 ForkJoinPool.commonPool())执行(还有其他选项,请参阅 API 文档)。您也可以通过 thenRunAsync(Runnable, Executor) 提供自己的线程池。您可以创建 Y 个这样的链(例如在一个映射表中保存对各条链末端的引用)。

在这种场景下,Java 世界中像 Akka 这样的系统就很有意义。如果 X 和 Y 都很大,您可能希望使用消息传递机制来处理它们,而不是把它们堆进一条巨大的回调和 Future 链中。每个单元格交给一个 actor,该 actor 持有待完成的任务列表,最终计算出结果并保存。即使某个中间步骤出了问题,也不至于全盘失败。

带有锁同步的回调机制可以在这里有效地工作。我之前回答过类似的问题。该方案有一些限制(参见链接的答案),但它足够简单,便于跟踪运行过程(可维护性好)。我改编了那里的源代码,使其更适合您的情况:大多数任务会并行执行(因为 n·m 很大),但当多个任务针对网格 G 中的同一个点时,它们必须串行执行。

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.locks.ReentrantLock;

// Adapted from 
public class GridTaskExecutor {

    public static void main(String[] args) {

        final int maxTasks = 10_000;
        final CountDownLatch tasksDone = new CountDownLatch(maxTasks);
        ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(16);
        try {
            GridTaskExecutor gte = new GridTaskExecutor(executor); 
            Random r = new Random();

            for (int i = 0; i < maxTasks; i++) {

                final int nx = r.nextInt(10);
                final int ny = r.nextInt(10);

                Runnable task = new Runnable() { 
                    public void run() { 
                        try {
                            // System.out.println("Task " + nx + " / " + ny + " is running");
                            Thread.sleep(1);
                        } catch (Exception e) {
                            e.printStackTrace();
                        } finally {
                            tasksDone.countDown();
                        }
                    } 
                };
                gte.addTask(task, nx, ny);
            }
            tasksDone.await();
            System.out.println("All tasks done, task points remaining: " + gte.size());
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            executor.shutdownNow();
        }
    }

    private final Executor executor;
    private final Map<Long, List<CallbackPointTask>> tasksWaiting = new HashMap<>();
    // make lock fair so that adding and removing tasks is balanced.
    private final ReentrantLock lock = new ReentrantLock(true);

    public GridTaskExecutor(Executor executor) {
        this.executor = executor;
    }

    public void addTask(Runnable r, int x, int y) {

        Long point = toPoint(x, y);
        CallbackPointTask pr = new CallbackPointTask(point, r);
        boolean runNow = false;
        lock.lock();
        try {
            List<CallbackPointTask> pointTasks = tasksWaiting.get(point);
            if (pointTasks == null) {
                if (tasksWaiting.containsKey(point)) {
                    pointTasks = new LinkedList<CallbackPointTask>();
                    pointTasks.add(pr);
                    tasksWaiting.put(point, pointTasks);
                } else {
                    tasksWaiting.put(point, null);
                    runNow = true;
                }
            } else {
                pointTasks.add(pr);
            }
        } finally {
            lock.unlock();
        }
        if (runNow) {
            executor.execute(pr);
        }
    }

    private void taskCompleted(Long point) {

        lock.lock();
        try {
            List<CallbackPointTask> pointTasks = tasksWaiting.get(point);
            if (pointTasks == null || pointTasks.isEmpty()) {
                tasksWaiting.remove(point);
            } else {
                System.out.println(Arrays.toString(fromPoint(point)) + " executing task " + pointTasks.size());
                executor.execute(pointTasks.remove(0));
            }
        } finally {
            lock.unlock();
        }
    }

    // for a general callback-task, see 
    private class CallbackPointTask implements Runnable {

        final Long point;
        final Runnable original;

        CallbackPointTask(Long point, Runnable original) {
            this.point = point;
            this.original = original;
        }

        @Override
        public void run() {

            try {
                original.run();
            } finally {
                taskCompleted(point);
            }
        }
    }

    /** Amount of points with tasks. */ 
    public int size() {

        int l = 0;
        lock.lock();
        try {
            l = tasksWaiting.size(); 
        } finally {
            lock.unlock();
        }
        return l;
    }

    // 
    public static long toPoint(int x, int y) {
        return (((long)x) << 32) | (y & 0xffffffffL);
    }

    public static int[] fromPoint(long p) {
        return new int[] {(int)(p >> 32), (int)p };
    }

}

您可以创建一个包含 n 个 Executors.newFixedThreadPool(1) 的列表,然后通过哈希函数把任务提交给对应的单线程执行器,这样同一个键的任务总是落在同一个线程上按顺序执行。例如:threadPool[key % n].submit(task)(如果 key 可能为负数,请改用 Math.floorMod(key, n) 计算下标)。

这个库应该能满足需求:https://github.com/jano7/executor

int maxTasks = 16;
// Shared worker pool; the keyed executor decides when each task may be handed to it.
ExecutorService threadPool = Executors.newFixedThreadPool(maxTasks);
// Runs tasks with equal keys one after another. The maxTasks argument presumably bounds the
// number of outstanding tasks -- confirm against the library's documentation.
KeySequentialBoundedExecutor executor = new KeySequentialBoundedExecutor(maxTasks, threadPool);

Random r = new Random();

for (int i = 0; i < 10000; i++) {
    int nx = r.nextInt(10);
    int ny = r.nextInt(10);
    // One key per grid cell: the row-major index of (nx, ny) in the 10 x 10 grid.
    int cellKey = ny * 10 + nx;

    Runnable task = () -> {
        try {
            System.out.println("Task is running");
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    };

    executor.execute(new KeyRunnable<>(cellKey, task));
}

下面的 Scala 示例演示了:映射中不同键的任务并行执行,而同一个键下的各个值按顺序串行执行。如果您想在 Java 中尝试,请将其改写为 Java 语法(示例使用的是 JVM 标准库)。基本思路是把同一个键的 Future 依次链接起来,使它们按顺序执行。

import java.util.concurrent.{CompletableFuture, ExecutorService, Executors, Future, TimeUnit}
import scala.collection.concurrent.TrieMap
import scala.collection.mutable.ListBuffer
import scala.util.Random

/**
 * For a given Key-Value pair with tasks as values, demonstrates sequential execution of tasks
 * within a key and parallel execution across keys.
 *
 * Tasks for one key are chained with `thenRunAsync`, so each starts only after the previous
 * task for that key has completed; chains for different keys are independent and share
 * `cachedPool`. `process`/`processASync` are meant to be called from a single thread (as `main`
 * does): the read-then-put on `runningTasks` is not atomic.
 */
object AsyncThreads {

  val cachedPool: ExecutorService = Executors.newCachedThreadPool
  var initialData: Map[String, ListBuffer[Int]] = Map()
  // Written from pool threads; a key's ListBuffer is only touched by that key's chain, one
  // task at a time, so the non-thread-safe ListBuffer is never accessed concurrently.
  var processedData: TrieMap[String, ListBuffer[Int]] = TrieMap()
  // Tail (most recently appended stage) of each key's task chain.
  var runningTasks: TrieMap[String, CompletableFuture[Void]] = TrieMap()

  /**
   * The unit of work: sleeps for `initialSleep` ms, then appends `value` to the key's list in
   * `processedData`. Prints a trace line for "key_0" only.
   */
  def processSync(key: String, value: Int, initialSleep: Long) = {
    Thread.sleep(initialSleep)
    if (key.equals("key_0")) {
      println(s"${Thread.currentThread().getName} -> sleep: $initialSleep. Inserting key_0 -> $value")
    }
    processedData.getOrElseUpdate(key, new ListBuffer[Int]).addOne(value)
  }

  /**
   * Schedules `processSync` on `cachedPool` behind all earlier tasks for the same key
   * (parallel across keys, sequential within a key).
   */
  def processASync(key: String, value: Int, initialSleep: Long) = {
    val task: Runnable = () => {
      processSync(key, value, initialSleep)
    }

    // 1. Chain the futures for sequential execution within a key
    val prevFuture = runningTasks.getOrElseUpdate(key, CompletableFuture.completedFuture(null))
    runningTasks.put(key, prevFuture.thenRunAsync(task, cachedPool))

    // 2. Parallel execution across keys and values
    // cachedPool.submit(task)
  }

  def process(key: String, value: Int, initialSleep: Int): Unit = {
    //processSync(key, value, initialSleep) // synchronous execution across keys and values
    processASync(key, value, initialSleep) // parallel execution across keys
  }

  def main(args: Array[String]): Unit = {

    checkDiff()

    (0 to 9).foreach { kIndex =>
      val key = "key_" + kIndex
      val values = ListBuffer[Int]()
      initialData += (key -> values)
      (1 to 10).foreach(vIndex => values += kIndex * 10 + vIndex)
    }

    println(s"before data:$initialData")

    initialData.foreach { case (key, values) =>
      values.foreach(value => process(key, value, Random.between(0, 100)))
    }

    try {
      // Wait until every key's chain has finished. The previous awaitTermination() without a
      // prior shutdown() just slept for the full timeout and guaranteed nothing.
      CompletableFuture.allOf(runningTasks.values.toSeq: _*).join()
    } finally {
      // Shut down only after the chains are done: a shut-down pool would reject the follow-up
      // stages that thenRunAsync still has to submit.
      cachedPool.shutdown()
      cachedPool.awaitTermination(5, TimeUnit.SECONDS)
    }
    println(s"after data:$processedData")

    println("diff: " + (initialData.toSet diff processedData.toSet).toMap)
  }

  /** Shows that the diff used in `main` is order-sensitive: lists with swapped elements differ. */
  def checkDiff(): Unit = {
    val a1: TrieMap[String, List[Int]] = new TrieMap()
    a1.put("one", List(1, 2, 3, 4, 5))
    a1.put("two", List(11, 12, 13, 14, 15))

    val a2: TrieMap[String, List[Int]] = new TrieMap()
    a2.put("one", List(2, 1, 3, 4, 5))
    a2.put("two", List(11, 12, 13, 14, 15))

    println("a1: " + a1)
    println("a2: " + a2)

    println("check.diff: " + (a1.toSet diff a2.toSet).toMap)
  }
}