在 java 中使用线程优化程序
Optimizing program using threads in java
我的目标是使用 Java 中的 ExecutorService
计算二叉树中元素的总和,然后使用 CompletionService
收集每个任务的结果。
用户给出树的高度,并行度应该开始的级别,以及要使用的线程数。我知道 ExecutorService
应该产生与用户给定的线程数量完全相同的线程,并且完成服务应该在 preProcess
方法中产生恰好 N 个任务,其中 N 是 2^(level of并行性),因为在某个级别 n,我们将有 2^n 个节点。
我的问题是我不知道如何从给定高度开始遍历树以及如何使用 CompletionService
在我的 postProcess
方法中收集结果。此外,每次产生新任务时,总任务数会增加一个,每次 CompletionService
返回结果时,任务数应减少一个。
我能够在 processTreeParallel
函数中使用 CompletionService
,但我真的不明白如何在我的 postProcess
方法中使用它。
这是我的代码:
import java.util.concurrent.*;
public class TreeCalculation {
// tree level to go parallel
int levelParallel;
// total number of generated tasks
long totalTasks;
// current number of open tasks
long nTasks;
// total height of tree
int height;
// Executors
ExecutorService exec;
CompletionService<Long> cs;
TreeCalculation(int height, int levelParallel) {
this.height = height;
this.levelParallel = levelParallel;
}
void incrementTasks() {
++nTasks;
++totalTasks;
}
void decrementTasks() {
--nTasks;
}
long getNTasks() {
return nTasks;
}
// Where the ExecutorService should be initialized
// with a specific threadCount
void preProcess(int threadCount) {
exec = Executors.newFixedThreadPool(threadCount);
cs = new ExecutorCompletionService<Long>(exec);
nTasks = 0;
totalTasks = 0;
}
// Where the CompletionService should collect the results;
long postProcess() {
long result = 0;
return result;
}
public static void main(String[] args) {
if (args.length != 3) {
System.out.println(
"usage: java Tree treeHeight levelParallel nthreads\n");
return;
}
int height = Integer.parseInt(args[0]);
int levelParallel = Integer.parseInt(args[1]);
int threadCount = Integer.parseInt(args[2]);
TreeCalculation tc = new TreeCalculation(height, levelParallel);
// generate balanced binary tree
Tree t = Tree.genTree(height, height);
//System.gc();
// traverse sequential
long t0 = System.nanoTime();
long p1 = t.processTree();
double t1 = (System.nanoTime() - t0) * 1e-9;
t0 = System.nanoTime();
tc.preProcess(threadCount);
long p2 = t.processTreeParallel(tc);
p2 += tc.postProcess();
double t2 = (System.nanoTime() - t0) * 1e-9;
long ref = (Tree.counter * (Tree.counter + 1)) / 2;
if (p1 != ref)
System.out.printf("ERROR: sum %d != reference %d\n", p1, ref);
if (p1 != p2)
System.out.printf("ERROR: sum %d != parallel %d\n", p1, p2);
if (tc.totalTasks != (2 << levelParallel)) {
System.out.printf("ERROR: ntasks %d != %d\n",
2 << levelParallel, tc.totalTasks);
}
// print timing
System.out.printf("tree height: %2d "
+ "sequential: %.6f "
+ "parallel with %3d threads and %6d tasks: %.6f "
+ "speedup: %.3f count: %d\n",
height, t1, threadCount, tc.totalTasks, t2, t1 / t2, ref);
}
}
// ============================================================================
class Tree {
static long counter; // counter for consecutive node numbering
int level; // node level
long value; // node value
Tree left; // left child
Tree right; // right child
// constructor
Tree(long value) {
this.value = value;
}
// generate a balanced binary tree of depth k
static Tree genTree(int k, int height) {
if (k < 0) {
return null;
} else {
Tree t = new Tree(++counter);
t.level = height - k;
t.left = genTree(k - 1, height);
t.right = genTree(k - 1, height);
return t;
}
}
// ========================================================================
// traverse a tree sequentially
long processTree() {
return value
+ ((left == null) ? 0 : left.processTree())
+ ((right == null) ? 0 : right.processTree());
}
// ========================================================================
// traverse a tree parallel
// This is where I was able to use the CompletionService
long processTreeParallel(TreeCalculation tc) {
tc.totalTasks = 0;
for(long i =0; i<(long)Math.pow(tc.levelParallel, 2); i++)
{
tc.incrementTasks();
tc.cs.submit(new Callable<Long>(){
@Override
public Long call() throws Exception {
return processTree();
}
});
}
Long result = Long.valueOf(0);
for(int i=0; i<(long)Math.pow(2,tc.levelParallel); i++) {
try{
result += tc.cs.take().get();
tc.decrementTasks();
}catch(Exception e){}
}
return result;
}
}
我假设您希望通过 运行 并行计算的不同部分来加速 "tree processing" 任务。
据我所知,您当前的解决方案提交了多个可调用项,每个可调用项都执行完全相同的操作,即对整个树进行处理。这意味着您正在多次执行相同的总体任务。这不太可能是您想要的。相反,您可能希望将整个任务拆分为几个部分任务。部分任务应该不重叠并且一起覆盖整个任务。对于这些部分任务,您可以并行执行它们,并以某种方式收集结果。
既然你是在树上执行,你就得想办法把树的处理分成合适的部分。最好以一种更容易设计和实现的方式,以及大多数合适大小的块,这样并行化将更有效。
现在的并行计算也是错误的,从运行程序可以看出,输入如下:
java TreeCalculation 10 2 4
对Math.pow(tc.levelParallel, 2)
和Math.pow(2,tc.levelParallel)
的调用也不同。
还要注意内存一致性问题。我一眼就没有看到,虽然你在这里和那里变异记忆。
这里的基本思想是遍历树,并像在 processTree
方法中那样计算结果。但是,一旦达到应该开始并行计算的级别(levelParallel
),您就会生成一个 实际上 在内部调用 processTree
的任务.这将处理树的剩余部分。
processTreeParallel 0
/ \
/ \
processTreeParallel 1 2
/ \ / \
processTreeParallel 3 4 5 6 <- levelParallel
| | | |
processTree call for each: v v v v
+---------------+
tasks for executor: |T T T T |
+---------------+
completion service |
fetches tasks and v
sums them up: T+T+T+T -> result
然后您必须添加由 processTreeParallel
方法的 sequential 部分计算的结果,以及由完成总结的任务结果服务。
processTreeParallel
方法可以这样实现:
long processTreeParallel(TreeCalculation tc)
{
if (level < tc.levelParallel)
{
long leftResult = left.processTreeParallel(tc);
long rightResult = right.processTreeParallel(tc);
return value + leftResult + rightResult;
}
tc.incrementTasks();
tc.cs.submit(new Callable<Long>()
{
@Override
public Long call() throws Exception
{
return processTree();
}
});
return 0;
}
完整的程序显示在这里:
import java.util.concurrent.*;
public class TreeCalculation
{
// tree level to go parallel
int levelParallel;
// total number of generated tasks
long totalTasks;
// current number of open tasks
long nTasks;
// total height of tree
int height;
// Executors
ExecutorService exec;
CompletionService<Long> cs;
TreeCalculation(int height, int levelParallel)
{
this.height = height;
this.levelParallel = levelParallel;
}
void incrementTasks()
{
++nTasks;
++totalTasks;
}
void decrementTasks()
{
--nTasks;
}
long getNTasks()
{
return nTasks;
}
// Where the ExecutorService should be initialized
// with a specific threadCount
void preProcess(int threadCount)
{
exec = Executors.newFixedThreadPool(threadCount);
cs = new ExecutorCompletionService<Long>(exec);
nTasks = 0;
totalTasks = 0;
}
// Where the CompletionService should collect the results;
long postProcess()
{
exec.shutdown();
long result = 0;
for (int i = 0; i < (long) Math.pow(2, levelParallel); i++)
{
try
{
result += cs.take().get();
decrementTasks();
}
catch (Exception e)
{
e.printStackTrace();
}
}
return result;
}
public static void main(String[] args)
{
int height = 22;
int levelParallel = 3;
int threadCount = 4;
if (args.length != 3)
{
System.out.println(
"usage: java Tree treeHeight levelParallel nthreads\n");
System.out.println("Using default values for test");
}
else
{
height = Integer.parseInt(args[0]);
levelParallel = Integer.parseInt(args[1]);
threadCount = Integer.parseInt(args[2]);
}
TreeCalculation tc = new TreeCalculation(height, levelParallel);
// generate balanced binary tree
Tree t = Tree.genTree(height, height);
// traverse sequential
long t0 = System.nanoTime();
long p1 = t.processTree();
double t1 = (System.nanoTime() - t0) * 1e-9;
t0 = System.nanoTime();
tc.preProcess(threadCount);
long p2 = t.processTreeParallel(tc);
p2 += tc.postProcess();
double t2 = (System.nanoTime() - t0) * 1e-9;
long ref = (Tree.counter * (Tree.counter + 1)) / 2;
if (p1 != ref)
System.out.printf("ERROR: sum %d != reference %d\n", p1, ref);
if (p1 != p2)
System.out.printf("ERROR: sum %d != parallel %d\n", p1, p2);
if (tc.totalTasks != (1 << levelParallel))
{
System.out.printf("ERROR: ntasks %d != %d\n", 1 << levelParallel,
tc.totalTasks);
}
// print timing
System.out.printf("tree height: %2d\n"
+ "sequential: %.6f\n"
+ "parallel with %3d threads and %6d tasks: %.6f\n"
+ "speedup: %.3f count: %d\n",
height, t1, threadCount, tc.totalTasks, t2, t1 / t2, ref);
}
}
// ============================================================================
class Tree
{
static long counter; // counter for consecutive node numbering
int level; // node level
long value; // node value
Tree left; // left child
Tree right; // right child
// constructor
Tree(long value)
{
this.value = value;
}
// generate a balanced binary tree of depth k
static Tree genTree(int k, int height)
{
if (k < 0)
{
return null;
}
Tree t = new Tree(++counter);
t.level = height - k;
t.left = genTree(k - 1, height);
t.right = genTree(k - 1, height);
return t;
}
// ========================================================================
// traverse a tree sequentially
long processTree()
{
return value
+ ((left == null) ? 0 : left.processTree())
+ ((right == null) ? 0 : right.processTree());
}
// ========================================================================
// traverse a tree parallel
long processTreeParallel(TreeCalculation tc)
{
if (level < tc.levelParallel)
{
long leftResult = left.processTreeParallel(tc);
long rightResult = right.processTreeParallel(tc);
return value + leftResult + rightResult;
}
tc.incrementTasks();
tc.cs.submit(new Callable<Long>()
{
@Override
public Long call() throws Exception
{
return processTree();
}
});
return 0;
}
}
我的目标是使用 Java 中的 ExecutorService
计算二叉树中元素的总和,然后使用 CompletionService
收集每个任务的结果。
用户给出树的高度,并行度应该开始的级别,以及要使用的线程数。我知道 ExecutorService
应该产生与用户给定的线程数量完全相同的线程,并且完成服务应该在 preProcess
方法中产生恰好 N 个任务,其中 N 是 2^(level of并行性),因为在某个级别 n,我们将有 2^n 个节点。
我的问题是我不知道如何从给定高度开始遍历树以及如何使用 CompletionService
在我的 postProcess
方法中收集结果。此外,每次产生新任务时,总任务数会增加一个,每次 CompletionService
返回结果时,任务数应减少一个。
我能够在 processTreeParallel
函数中使用 CompletionService
,但我真的不明白如何在我的 postProcess
方法中使用它。
这是我的代码:
import java.util.concurrent.*;
public class TreeCalculation {
// tree level to go parallel
int levelParallel;
// total number of generated tasks
long totalTasks;
// current number of open tasks
long nTasks;
// total height of tree
int height;
// Executors
ExecutorService exec;
CompletionService<Long> cs;
TreeCalculation(int height, int levelParallel) {
this.height = height;
this.levelParallel = levelParallel;
}
void incrementTasks() {
++nTasks;
++totalTasks;
}
void decrementTasks() {
--nTasks;
}
long getNTasks() {
return nTasks;
}
// Where the ExecutorService should be initialized
// with a specific threadCount
void preProcess(int threadCount) {
exec = Executors.newFixedThreadPool(threadCount);
cs = new ExecutorCompletionService<Long>(exec);
nTasks = 0;
totalTasks = 0;
}
// Where the CompletionService should collect the results;
long postProcess() {
long result = 0;
return result;
}
public static void main(String[] args) {
if (args.length != 3) {
System.out.println(
"usage: java Tree treeHeight levelParallel nthreads\n");
return;
}
int height = Integer.parseInt(args[0]);
int levelParallel = Integer.parseInt(args[1]);
int threadCount = Integer.parseInt(args[2]);
TreeCalculation tc = new TreeCalculation(height, levelParallel);
// generate balanced binary tree
Tree t = Tree.genTree(height, height);
//System.gc();
// traverse sequential
long t0 = System.nanoTime();
long p1 = t.processTree();
double t1 = (System.nanoTime() - t0) * 1e-9;
t0 = System.nanoTime();
tc.preProcess(threadCount);
long p2 = t.processTreeParallel(tc);
p2 += tc.postProcess();
double t2 = (System.nanoTime() - t0) * 1e-9;
long ref = (Tree.counter * (Tree.counter + 1)) / 2;
if (p1 != ref)
System.out.printf("ERROR: sum %d != reference %d\n", p1, ref);
if (p1 != p2)
System.out.printf("ERROR: sum %d != parallel %d\n", p1, p2);
if (tc.totalTasks != (2 << levelParallel)) {
System.out.printf("ERROR: ntasks %d != %d\n",
2 << levelParallel, tc.totalTasks);
}
// print timing
System.out.printf("tree height: %2d "
+ "sequential: %.6f "
+ "parallel with %3d threads and %6d tasks: %.6f "
+ "speedup: %.3f count: %d\n",
height, t1, threadCount, tc.totalTasks, t2, t1 / t2, ref);
}
}
// ============================================================================
class Tree {
static long counter; // counter for consecutive node numbering
int level; // node level
long value; // node value
Tree left; // left child
Tree right; // right child
// constructor
Tree(long value) {
this.value = value;
}
// generate a balanced binary tree of depth k
static Tree genTree(int k, int height) {
if (k < 0) {
return null;
} else {
Tree t = new Tree(++counter);
t.level = height - k;
t.left = genTree(k - 1, height);
t.right = genTree(k - 1, height);
return t;
}
}
// ========================================================================
// traverse a tree sequentially
long processTree() {
return value
+ ((left == null) ? 0 : left.processTree())
+ ((right == null) ? 0 : right.processTree());
}
// ========================================================================
// traverse a tree parallel
// This is where I was able to use the CompletionService
long processTreeParallel(TreeCalculation tc) {
tc.totalTasks = 0;
for(long i =0; i<(long)Math.pow(tc.levelParallel, 2); i++)
{
tc.incrementTasks();
tc.cs.submit(new Callable<Long>(){
@Override
public Long call() throws Exception {
return processTree();
}
});
}
Long result = Long.valueOf(0);
for(int i=0; i<(long)Math.pow(2,tc.levelParallel); i++) {
try{
result += tc.cs.take().get();
tc.decrementTasks();
}catch(Exception e){}
}
return result;
}
}
我假设您希望通过 运行 并行计算的不同部分来加速 "tree processing" 任务。
据我所知,您当前的解决方案提交了多个可调用项,每个可调用项都执行完全相同的操作,即对整个树进行处理。这意味着您正在多次执行相同的总体任务。这不太可能是您想要的。相反,您可能希望将整个任务拆分为几个部分任务。部分任务应该不重叠并且一起覆盖整个任务。对于这些部分任务,您可以并行执行它们,并以某种方式收集结果。
既然你是在树上执行,你就得想办法把树的处理分成合适的部分。最好以一种更容易设计和实现的方式,以及大多数合适大小的块,这样并行化将更有效。
现在的并行计算也是错误的,从运行程序可以看出,输入如下:
java TreeCalculation 10 2 4
对Math.pow(tc.levelParallel, 2)
和Math.pow(2,tc.levelParallel)
的调用也不同。
还要注意内存一致性问题。我一眼就没有看到,虽然你在这里和那里变异记忆。
这里的基本思想是遍历树,并像在 processTree
方法中那样计算结果。但是,一旦达到应该开始并行计算的级别(levelParallel
),您就会生成一个 实际上 在内部调用 processTree
的任务.这将处理树的剩余部分。
processTreeParallel 0
/ \
/ \
processTreeParallel 1 2
/ \ / \
processTreeParallel 3 4 5 6 <- levelParallel
| | | |
processTree call for each: v v v v
+---------------+
tasks for executor: |T T T T |
+---------------+
completion service |
fetches tasks and v
sums them up: T+T+T+T -> result
然后您必须添加由 processTreeParallel
方法的 sequential 部分计算的结果,以及由完成总结的任务结果服务。
processTreeParallel
方法可以这样实现:
long processTreeParallel(TreeCalculation tc)
{
if (level < tc.levelParallel)
{
long leftResult = left.processTreeParallel(tc);
long rightResult = right.processTreeParallel(tc);
return value + leftResult + rightResult;
}
tc.incrementTasks();
tc.cs.submit(new Callable<Long>()
{
@Override
public Long call() throws Exception
{
return processTree();
}
});
return 0;
}
完整的程序显示在这里:
import java.util.concurrent.*;
public class TreeCalculation
{
// tree level to go parallel
int levelParallel;
// total number of generated tasks
long totalTasks;
// current number of open tasks
long nTasks;
// total height of tree
int height;
// Executors
ExecutorService exec;
CompletionService<Long> cs;
TreeCalculation(int height, int levelParallel)
{
this.height = height;
this.levelParallel = levelParallel;
}
void incrementTasks()
{
++nTasks;
++totalTasks;
}
void decrementTasks()
{
--nTasks;
}
long getNTasks()
{
return nTasks;
}
// Where the ExecutorService should be initialized
// with a specific threadCount
void preProcess(int threadCount)
{
exec = Executors.newFixedThreadPool(threadCount);
cs = new ExecutorCompletionService<Long>(exec);
nTasks = 0;
totalTasks = 0;
}
// Where the CompletionService should collect the results;
long postProcess()
{
exec.shutdown();
long result = 0;
for (int i = 0; i < (long) Math.pow(2, levelParallel); i++)
{
try
{
result += cs.take().get();
decrementTasks();
}
catch (Exception e)
{
e.printStackTrace();
}
}
return result;
}
public static void main(String[] args)
{
int height = 22;
int levelParallel = 3;
int threadCount = 4;
if (args.length != 3)
{
System.out.println(
"usage: java Tree treeHeight levelParallel nthreads\n");
System.out.println("Using default values for test");
}
else
{
height = Integer.parseInt(args[0]);
levelParallel = Integer.parseInt(args[1]);
threadCount = Integer.parseInt(args[2]);
}
TreeCalculation tc = new TreeCalculation(height, levelParallel);
// generate balanced binary tree
Tree t = Tree.genTree(height, height);
// traverse sequential
long t0 = System.nanoTime();
long p1 = t.processTree();
double t1 = (System.nanoTime() - t0) * 1e-9;
t0 = System.nanoTime();
tc.preProcess(threadCount);
long p2 = t.processTreeParallel(tc);
p2 += tc.postProcess();
double t2 = (System.nanoTime() - t0) * 1e-9;
long ref = (Tree.counter * (Tree.counter + 1)) / 2;
if (p1 != ref)
System.out.printf("ERROR: sum %d != reference %d\n", p1, ref);
if (p1 != p2)
System.out.printf("ERROR: sum %d != parallel %d\n", p1, p2);
if (tc.totalTasks != (1 << levelParallel))
{
System.out.printf("ERROR: ntasks %d != %d\n", 1 << levelParallel,
tc.totalTasks);
}
// print timing
System.out.printf("tree height: %2d\n"
+ "sequential: %.6f\n"
+ "parallel with %3d threads and %6d tasks: %.6f\n"
+ "speedup: %.3f count: %d\n",
height, t1, threadCount, tc.totalTasks, t2, t1 / t2, ref);
}
}
// ============================================================================
class Tree
{
static long counter; // counter for consecutive node numbering
int level; // node level
long value; // node value
Tree left; // left child
Tree right; // right child
// constructor
Tree(long value)
{
this.value = value;
}
// generate a balanced binary tree of depth k
static Tree genTree(int k, int height)
{
if (k < 0)
{
return null;
}
Tree t = new Tree(++counter);
t.level = height - k;
t.left = genTree(k - 1, height);
t.right = genTree(k - 1, height);
return t;
}
// ========================================================================
// traverse a tree sequentially
long processTree()
{
return value
+ ((left == null) ? 0 : left.processTree())
+ ((right == null) ? 0 : right.processTree());
}
// ========================================================================
// traverse a tree parallel
long processTreeParallel(TreeCalculation tc)
{
if (level < tc.levelParallel)
{
long leftResult = left.processTreeParallel(tc);
long rightResult = right.processTreeParallel(tc);
return value + leftResult + rightResult;
}
tc.incrementTasks();
tc.cs.submit(new Callable<Long>()
{
@Override
public Long call() throws Exception
{
return processTree();
}
});
return 0;
}
}