프로그래밍 연습장

Heavy Light Decomposition의 구현 본문

알고리즘

Heavy Light Decomposition의 구현

김준원 2017. 12. 3. 14:08

Heavy Light Decomposition, HLD는 트리에 관한 동적 쿼리에 사용될 수 있는 대표적인 알고리즘이다. 비선형 자료구조인 트리를 여러 개의 선형 경로로 분리하여, 선형 자료구조에서 쓸 수 있었던 segment tree 등의 기법을 이용할 수 있게 해 준다. 그러면 각 경로를 하나의 정점으로 압축할 수 있게 되고, 어떠한 형태의 트리이든 높이가 최대 $O( \log n)$인 트리에 대한 쿼리로서 해결할 수 있게 된다.


HLD에 대한 자세한 설명은 다른 게시글들에 더 자세히 설명이 되어 있기에 건너뛴다. 간단히, <가장 '무거운' 간선을 따라 내려간 경로로 트리를 묶으면, 트리 위의 어떠한 경로도 길이가 $O( \log n)$을 넘지 않는다>가 핵심적인 원리다. '균형이 잡히지 않은' 트리를 강제로 균형이 잡히게 만드는 셈이다.


HLD가 어떻게 해서 효과적인 성능을 발휘하는지는 다음 게시글들을 참조하자.


(영문) https://blog.anudeep2011.com/heavy-light-decomposition/

(한글) http://theyearlyprophet.com/heavy-light-decomposition.html (다소 다른 관점으로 바라보고 구현하였으나, 원리는 같다)


사진 출처 #

HL Decomposed Graph


HLD의 구현은 크게 세 가지로 나누어지게 된다.


1. HLD의 체인들을, 원래의 루트가 없었던 트리에서 루트를 잡고, 형성한다.


2. 경로 관련 쿼리를 처리하기 위해, LCA(최소 공통 조상)를 구하기 위한 sparse table과, 각 경로들에 대한 세그먼트 트리를 구성한다. 세그먼트 트리는 PST를 구현할 때와 같이 포인터 또는 배열 기반으로 만들 수 있다.


3. 각 쿼리를 체인을 넘나들면서 처리하도록 구현한다.


특히 HLD같이 복잡한 구현이 필요한 자료구조를 한 번에 코딩하기는 굉장히 어렵다. 여러 개로 과제를 분할한 다음에 생각하는 것이 중요하다.


읽기 전에, 이 문제는 acmicpc.net의 13510번 문제를 푼 코드를 그대로 가져왔다. 특정한 구현 방식에 매몰될 필요는 없다.


먼저, HLD의 체인들을 구성하는 부분이다. 루트 없는 트리에서 루트를 잡고 정점들의 깊이, 부모, 무게를 알아내는 첫 DFS인 dfs0 함수. 그리고 루트로부터 시작해 가장 무거운(부트리의 크기가 큰) 자식들이 아닌 자식들에게 새로운 체인을 부여해가며 체인을 형성해가는 dfs 함수이다.


1. dfs0 함수는 DFS로 파악할 수 있는 그래프의 기본적인 정보들을 파악하게 해 준다. DFS를 시작한 임의의 정점(임의이므로 보통 1번이라고 간주한다)을 이 트리의 루트라고 간주하고, 각 정점에 대해 루트와의 거리(깊이), 부모 정점의 번호, 부트리에 있는 정점의 수(무게)를 재귀적으로 계산해 나간다. DFS를 알고 있다면 분석하기 어렵지 않을 것이다.


2. 두번째 DFS인 dfs 함수가 핵심적이다. dfs0에서 잡았던 1번 정점을 계속 루트라고 가정하고 구현한다. 먼저, 루트는 체인의 시작이다. 그리고, 각 정점에 대해서, 가장 크기가 큰 자식만이 그 정점의 체인을 잇고, 다른 자식들은 새로운 체인의 시작이 된다. 코드는 이를 말 그대로 나타내고 있으며, (정점 번호, 정점과 부모를 잇는 간선의 무게)의 pair로 체인을 이루고 있다. 체인의 head, 즉 체인에서 가장 깊이가 낮은 정점부터 차례대로 들어가게 된다. 쓰이는 emplace 함수에 대해서는 #를 참조하라.


참고로, 3, 12, 15번째 줄에서 나타나는 for문의 형태는 간선 리스트로의 그래프의 구현에서 정점과 연결된 다른 정점들을 순회하는 반복문이다. b[i]는 k와 연결된 정점 하나의 번호이고, w[i]는 (k, b[i])를 잇는 간선의 가중치이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
int dfs0(int k, int p, int d) {
    dep[k] = d; prt[k] = p; sz[k] = 1;
    for (int i=st[k]; i; i=nxt[i])
        if (b[i] != p) sz[k] += dfs0(b[i], k, d+1);
    return sz[k];
}
 
vector<vector<pint> > grp;
void dfs(int k, int nw, vector<pint>& v) {
    v.emplace_back(k, nw);
    int heavy = 0, mx = 0;
    for (int i=st[k]; i; i=nxt[i]) if (b[i] != prt[k]) {
        if (mx < sz[b[i]]) heavy = b[i], mx = sz[b[i]];
    }
    for (int i=st[k]; i; i=nxt[i]) if (b[i] != prt[k]) {
        if (b[i] == heavy) dfs(b[i], w[i], v);
        else {
            grp.emplace_back();
            dfs(b[i], w[i], grp[grp.size()-1]);
        }
    }
}
 
cs


이 코드를 실행한 다음에는, 2차원 벡터인 grp에는 각 체인에 속하는 각 정점에 대해 (정점 번호, 정점과 부모를 연결하는 간선의 가중치)의 pair가 저장되어 있다. grp에 대해 매번 emplace_back을 하는 것은 부하가 크고, 자칫하면 런타임 에러를 야기할 수 있으므로 위를 실행하기 전에 grp.reserve(maxn);으로 미리 최대 크기의 메모리를 잡아놓는 것이 좋다.


다음, LCA의 전처리는 건너뛰고, 다음은 2와 3을 처리하기 위해 구성한 세그먼트 트리이다. construct 함수는 전처리를 위한 것이고, update와 query 함수는 쿼리 처리를 하기 위한 것이다. 구조체 안에서의 연산은 각각의 체인에서에 한정되어 있으므로, 실제 체인 안에서 이루어지지 않고 임의의 경로에 대해 이루어지는 쿼리를 처리할 때에는 다른 외부 함수의 도움이 필요하다. 기본적으로 PST를 구현했을 때에 사용한 것과 거의 동일하다. 각 노드는 $t$ 배열에 들어 있으며, $root$ 배열에 들어있는 인덱스의 노드는 각 체인에 해당하는 트리의 루트이다. $sz$는 각 트리의 크기이다. 각 노드 $x$는 $l[x]$에 왼쪽 자식을, $r[x]$에 오른쪽 자식을 두고 있으며(부모에는 관심이 없다) 각 함수는 자식들에게 재귀적으로 함수를 보내서 결과를 종합하여 반환한다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
struct tree {
    int root[maxn], sz[maxn], t[maxn*2], l[maxn*2], r[maxn*2], now;
    int construct(vector<pint>& v, int s, int e) {
        int x = now++;
        if (s==e) t[x] = v[s].y;
        else {
            int m = (s+e)/2;
            l[x] = construct(v, s, m);
            r[x] = construct(v, m+1, e);
            t[x] = max(t[l[x]], t[r[x]]);
        }
        return x;
    } void construct(vector<pint>& v, int k) { root[k] = construct(v, 0, v.size()-1); sz[k] = v.size(); }
    void update(int s, int e, int x, int p, int v) {
        if (s==e) t[x] = v;
        else {
            int m = (s+e)/2;
            if (p<=m) update(s, m, l[x], p, v);
            else update(m+1, e, r[x], p, v);
            t[x] = max(t[l[x]], t[r[x]]);
        }
    } void update(int k, int p, int v) { update(0, sz[k]-1, root[k], p, v); }
    int query(int s, int e, int x, int ql, int qr) {
        if (ql<=s and e<=qr) return t[x];
        else if (ql<=e and s<=qr) {
            int m = (s+e)/2;
            return max(query(s, m, l[x], ql, qr), query(m+1, e, r[x], ql, qr));
        } return 0;
    } int query(int k, int l, int r) { return query(0, sz[k]-1, root[k], l, r); }
} t;
 
 
 
cs


그리고, 중요한 전처리를 하나 더 거치게 되는데, 만들어진 체인을 순회하면서 각 정점이 어떤 체인의 몇 번째 원소에 속해 있는지를 저장해두게 된다. 체인 배열의 역배열을 하나 만들었다고 생각하면 된다. 이 정보를 이용해 어떤 체인에서 어떤 범위의 정점에서 쿼리를 수행해야 하는지 알 수 있다.


여기까지 구현했다면 쿼리를 구현하는 것은 쉽다. 쿼리 $(u, v)$를 $l = lca(u, v)$를 이용해 $(u, l), (v, l)$로 분리한다. 각각의 $(v, l)$에 대해, $v$와 $l$이 같은 체인에 속해있다면 아까 구성한 트리로 쿼리를 처리하면 되고, 아니면 그 체인의 가장 위에 있는 노드인 $h$를 가지고 와서 $(v, h), (p[h], l)$로 쿼리를 분리해 재귀적으로 처리하면 된다.


1
2
3
4
5
6
7
8
9
int down(int x, int p) {
    if (id[x].x == id[p].x) return t.query(id[x].x, id[p].y+1, id[x].y);
    return max(t.query(id[x].x, 0, id[x].y), down(prt[grp[id[x].x][0].x], p));
}
 
int query(int x, int y) {
    int l = lca(x, y);
    return max(down(x, l), down(y, l));
}
cs


아래에 있는 전체 코드에서 이를 확인할 수 있다. 이것은 acmicpc.net의 13510번 문제를 푼 코드이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include <bits/stdc++.h>
using namespace std;
 
typedef pair<intint> pint;
#define x first
#define y second
 
const int maxn = 100004;
int b[maxn*2], w[maxn*2], nxt[maxn*2], st[maxn], np = 2;
void add(int x, int y, int wt) {
    b[np] = y; w[np] = wt; nxt[np] = st[x]; st[x] = np++;
}
int n, dep[maxn], prt[maxn], sz[maxn], spa[maxn][20], ex[maxn], ey[maxn];
pint id[maxn];
vector<vector<pint> > grp;
 
struct tree {
    int root[maxn], sz[maxn], t[maxn*2], l[maxn*2], r[maxn*2], now;
    int construct(vector<pint>& v, int s, int e) {
        int x = now++;
        if (s==e) t[x] = v[s].y;
        else {
            int m = (s+e)/2;
            l[x] = construct(v, s, m);
            r[x] = construct(v, m+1, e);
            t[x] = max(t[l[x]], t[r[x]]);
        }
        return x;
    } void construct(vector<pint>& v, int k) { root[k] = construct(v, 0, v.size()-1); sz[k] = v.size(); }
    void update(int s, int e, int x, int p, int v) {
        if (s==e) t[x] = v;
        else {
            int m = (s+e)/2;
            if (p<=m) update(s, m, l[x], p, v);
            else update(m+1, e, r[x], p, v);
            t[x] = max(t[l[x]], t[r[x]]);
        }
    } void update(int k, int p, int v) { update(0, sz[k]-1, root[k], p, v); }
    int query(int s, int e, int x, int il, int ir) {
        if (il<=s and e<=ir) return t[x];
        else if (il<=e and s<=ir) {
            int m = (s+e)/2;
            return max(query(s, m, l[x], il, ir), query(m+1, e, r[x], il, ir));
        } return 0;
    } int query(int k, int l, int r) { return query(0, sz[k]-1, root[k], l, r); }
} t;
 
int dfs0(int k, int p, int d) {
    dep[k] = d; prt[k] = p; sz[k] = 1;
    for (int i=st[k]; i; i=nxt[i]) if (b[i] != p) sz[k] += dfs0(b[i], k, d+1);
    return sz[k];
}
 
void dfs(int k, int nw, vector<pint>& v) {
    v.emplace_back(k, nw);
    int heavy = 0, mx = 0;
    for (int i=st[k]; i; i=nxt[i]) if (b[i] != prt[k]) {
        if (mx < sz[b[i]]) heavy = b[i], mx = sz[b[i]];
    }
    for (int i=st[k]; i; i=nxt[i]) if (b[i] != prt[k]) {
        if (b[i] == heavy) dfs(b[i], w[i], v);
        else {
            grp.emplace_back();
            dfs(b[i], w[i], grp[grp.size()-1]);
        }
    }
}
 
int lca(int x, int y) {
    if (dep[x] > dep[y]) return lca(y, x);
    for (int i=19, d = dep[y]-dep[x]; ~i; i--if ((d>>i)&1) y = spa[y][i];
    for (int i=19; ~i; i--if (spa[x][i] != spa[y][i]) x = spa[x][i], y = spa[y][i];
    return x == y ? x : spa[x][0];
}
 
int down(int x, int p) {
    if (id[x].x == id[p].x) return t.query(id[x].x, id[p].y+1, id[x].y);
    return max(t.query(id[x].x, 0, id[x].y), down(prt[grp[id[x].x][0].x], p));
}
 
int query(int x, int y) {
    int l = lca(x, y);
    return max(down(x, l), down(y, l));
}
 
int main()
{
    scanf("%d"&n);
    for (int i=1; i<n; i++) {
        int u, v, w;
        scanf("%d%d%d"&u, &v, &w); ex[i] = u; ey[i] = v;
        add(u, v, w); add(v, u, w);
    }
 
    grp.reserve(100000);
    dfs0(111);
    grp.emplace_back();
    dfs(10, grp[0]);
 
    for (int i=0; i<grp.size(); i++)
        for (int j=0; j<grp[i].size(); j++) id[grp[i][j].x] = pint(i, j);
 
    for (int i=1; i<=n; i++) spa[i][0= prt[i];
    for (int j=1; j<20; j++for (int i=1; i<=n; i++) spa[i][j] = spa[spa[i][j-1]][j-1];
 
    for (int i=0; i<grp.size(); i++) t.construct(grp[i], i);
 
    int q;
    scanf("%d"&q);
 
    while (q--) {
        int qt, x, y;
        scanf("%d%d%d"&qt, &x, &y);
        if (qt==1) {
            x = dep[ex[x]] < dep[ey[x]] ? ey[x] : ex[x];
            t.update(id[x].x, id[x].y, y);
        }
        else printf("%d\n", query(x, y));
    }
}
 
cs


Comments