给定一个具有节点权重的二分图,根据一定的启发式得到一种类型节点的有序列表

Given a bipartite graph with node weight get ordered list of one type of node based on certain heuristic

给定一个 bi-partite 图,其中 A 类和 B 类节点的节点权重如下所示:

我想输出由以下启发式定义的 B 类节点的有序列表:

  1. 对于类型 B 的每个节点,我们对该节点具有边缘的类型 A 的节点权重求和,并将总和乘以其自身的节点权重以获得节点值。
  2. 然后我们select类型B中具有最高值的节点并将其附加到输出集S。
  3. 我们从类型 B 中删除 selected 节点,并从类型 A 中删除它有一条边的所有节点。
  4. 返回到步骤 1,直到类型 B 中的任何节点都留下了与类型 A 中的节点的边。
  5. 将类型 B 的任何剩余节点按其节点权重的顺序附加到输出集中。

下图为示例:

对于这个例子,输出集将是:(Y, Z, X)

天真的过程是简单地遍历这个算法,但假设 bi-partite 图很大,我正在寻找找到输出集的最有效方法。请注意,我只需要 B 型节点的有序列表作为输出 而没有 中间计算值(例如 50、15、2)

这是 Dave 在评论中建议的算法的进一步改进。它最大限度地减少了需要重新计算节点值的次数。

  1. 运行通过第1步,通过val
  2. 将生成的B节点放入最大堆中
  3. 检查顶级节点是否删除了它的任何邻居。如果是,则重新计算并重新插入到堆中。如果没有,添加到输出并删除邻居。
  4. 重复2直到所有B都输出

我已经根据我的 PathFinder graph class 在 C++ 中实现了这个算法。代码 运行ning 在具有一半 a 和一半 b 节点的 100 万节点图上,每个 b 节点连接到两个随机 a 节点,需要 1 秒。

这是代码

void cPathFinder::karup()
    {
        raven::set::cRunWatch aWatcher("karup");
        std::cout << "karup on " << nodeCount() << " node graph\n";
        std::vector<int> output;

        // calculate initial values of B nodes
        std::multimap<int, int> mapValueNode;
        for (auto &b : nodes())
        {
            if (b.second.myName[0] != 'b')
                continue;
            int value = 0;
            for (auto a : b.second.myLink)
            {
                value += node(a.first).myCost;
            }
            value *= b.second.myCost;
            mapValueNode.insert(std::make_pair(value, b.first));
        }

        // while not all B nodes output
        while (mapValueNode.size())
        {
            raven::set::cRunWatch aWatcher("select");

            // select node with highest value
            auto remove_it = --mapValueNode.end();
            int remove = remove_it->second;

            if (!remove_it->first)
            {
                /** all remaining nodes have zero value
                 * all the links from B nodes to A nodes have been removed
                 * output remaining nodes in order of decreasing node weight
                 */
                raven::set::cRunWatch aWatcher("Bunlinked");
                std::multimap<int, int> mapNodeValueNode;
                for (auto &nv : mapValueNode)
                {
                   mapNodeValueNode.insert( 
                       std::make_pair( 
                           node(nv.second).myCost,
                           nv.second ));
                }
                for( auto& nv : mapNodeValueNode )
                {
                    myPath.push_back( nv.second );
                }
                break;
            }

            bool OK = true;
            int value = 0;
            {
                raven::set::cRunWatch aWatcher("check");

                // check that no nodes providing value have been removed

                // std::cout << "checking neighbors of " << name(remove) << "\n";

                auto &vl = node(remove).myLink;
                for (auto it = vl.begin(); it != vl.end();)
                {
                    if (!myG.count(it->first))
                    {
                        // A neighbour has been removed
                        OK = false;
                        it = vl.erase(it);
                    }
                    else
                    {
                        // A neighbour remains
                        value += node(it->first).myCost;
                        it++;
                    }
                }
            }

            if (OK)
            {
                raven::set::cRunWatch aWatcher("store");
                // we have a node whose values is highest and valid

                // store result
                output.push_back(remove);

                // remove neighbour A nodes
                auto &ls = node(remove).myLink;
                for (auto &l : ls)
                {
                    myG.erase(l.first);
                }
                // remove the B node
                // std::cout << "remove " << name( remove ) << "\n";
                mapValueNode.erase(remove_it);
            }
            else
            {
                // replace old value with new
                raven::set::cRunWatch aWatcher("replace");
                value *= node(remove).myCost;
                mapValueNode.erase(remove_it);
                mapValueNode.insert(std::make_pair(value, remove));
            }
        }
    }

这是计时结果

karup on 1000000 node graph
raven::set::cRunWatch code timing profile
Calls           Mean (secs)     Total           Scope
       1        1.16767 1.16767 karup
  581457        1.37921e-06     0.801951        select
  581456        4.71585e-07     0.274206        check
  564546        3.04042e-07     0.171646        replace
       1        0.153269        0.153269        Bunlinked
   16910        8.10422e-06     0.137042        store

我在 C++ 中提供了一个解决方案,它基本上类似于@ravenspoint 的想法。它维护一个堆,每次取值最大的B节点。这里我使用 priority_queue 而不是 set 因为第一个比第二个快得多。


#include <chrono>
#include <iostream>
#include <queue>
#include <vector>

int nA, nB;
std::vector<int> A, B, sum;
std::vector<std::vector<int>> adjA, adjB;
inline std::vector<int> solve() {
    struct Node {
        // We store the value of the node `x` WHEN IT IS INSERTED
        // Modifying the value of the node `x` (sum) won't affect this Node basically
        int x, val;

        Node(int x): x(x), val(sum[x] * B[x]) {}

        bool operator<(const Node &t) const { return val == t.val? (B[x] < B[t.x]): (val < t.val); }
    };

    std::priority_queue<Node> q;
    std::vector<bool> delA(nA, false), delB(nB, false);
    std::vector<int> ret; ret.reserve(nB);

    for (int x = 0; x < nA; ++x)
        for (int y : adjA[x]) sum[y] += A[x];
    for (int y = 0; y < nB; ++y) q.emplace(y);
    while (!q.empty()) {
        const Node node = q.top(); q.pop();
        const int y = node.x;
        if (sum[y] * B[y] != node.val || delB[y]) // This means this Node is obsolete
            continue;
        delB[y] = true;
        ret.push_back(y);
        for (int x : adjB[y]) {
            if (delA[x]) continue;
            delA[x] = true;
            for (int ny : adjA[x]) {
                if (delB[ny]) continue;
                sum[ny] -= A[x];
                // This happens at most `m` time
                q.emplace(ny);
            }
        }
    }

    return ret;
}
int main() {
    std::cout << "Number of nodes in type A: "; std::cin >> nA;
    A.resize(nA); adjA.resize(nA);
    std::cout << "Weights of nodes in type A: ";
    for (int &v : A) std::cin >> v;

    std::cout << "Number of nodes in type B: "; std::cin >> nB;
    B.resize(nB); adjB.resize(nB); sum.resize(nB, 0);
    std::cout << "Weights of nodes in type B: ";
    for (int &v : B) std::cin >> v;

    int m;
    std::cout << "Number of edges: "; std::cin >> m;
    std::cout << "Edges: " << std::endl;
    for (int i = 0; i < m; ++i) {
        int x, y; std::cin >> x >> y;
        --x; --y;
        adjA[x].push_back(y);
        adjB[y].push_back(x);
    }

    auto st_time = std::chrono::steady_clock::now();
    auto ret = solve();
    auto en_time = std::chrono::steady_clock::now();
    std::cout << "Answer:";
    for (int v : ret) std::cout << ' ' << (v + 1);
    std::cout << std::endl;

    std::cout << "Took "
        << std::chrono::duration_cast<std::chrono::milliseconds>(en_time - st_time).count()
        << "ms" << std::endl;
}

我在 nA = nB = 1e6, m = 2e6 处随机生成了一些批次的数据,并且程序总是可以在我的计算机上不到 800ms 的时间内产生答案(不考虑 IO 时间,O2 启用)。此解决方案的时间复杂度为 O((m+n)log m),因为 emplace 最多调用 n+m 次。

抱歉我的英语不好。欢迎指出我的错别字和错误。