跳转至

线段树简记

约 30 个字    52 行代码  预计阅读时间 1 分钟

1 建树 (buildTree)

/* 对 [l, r] 区间建树,当前根节点为 f */
void buildTree(int l, int r, int f)
{
    if(l == r)
    {
        // 区间长度为 1 时直接返回 a[] 数组中的根节点值
        d[f] = a[l];
        return;
    }
    int m = l + ((r - l) >> 1); // 这里用到了位运算,等价于 `m = (l + r) / 2`
    build(l, m, p << 1), build(m + 1, r, (p << 1) | 1); // p << 1 即 2 * p (左子节点),(p << 1) | 1 为右子节点
    d[f] = d[p << 1] + d[(p << 1) | 1]; // 向下递归相加求值
}

2 区间查询 (query)

2.1 求和 (getSum)

/* [tl, tr] 为查询区间, [nl, nr] 为当前节点包含的区间, nf 为当前节点的编号 */
int getSum(int tl, int tr, int nl, int nr, int nf)
{
    if(tl <= nl && nr <= tr) return d[f]; // 目标区间包含当前区间时,返回当前区间的节点值(和)
    int m = nl + ((nr - nl) >> 1), sum = 0;
    if(b[nf])
    {
        d[nf << 1] += b[nf] * (m - nl + 1), d[(nf << 1) | 1] += b[nf] * (nr - m);
        b[nf << 1] += b[nf], b[(nf << 1) | 1] += b[nf];
        b[nf] = 0; // 清空父节点标记
    }
    int sum = 0;
    if(tl <= m) sum += getSum(tl, tr, nl, m, nf << 1);
    if(tr > m) sum += getSum(tl, tr, m, nr, (nf << 1) | 1);
    return sum; // 返回求和结果
}

3 区间修改 (updateTree)

/* [tl, tr] 为修改区间, dt 为被修改的元素的变化量, [nl, nr] 为当前节点包含的区间, nf 为当前节点的编号 */
void updateTree(int tl, int tr, int dt, int nl, int nr, int nf)
{
    if(tl <= nl && nr <= tr)
    {
        // 如果当前区间被包含在要修改的目标区间,则直接修改当前区间
        d[nf] += (nr - nl + 1) * dt, b[nf] += dt; // (nr - nl + 1) 为当前区间长度,b[] 数组为变化量标记
        return;
    }
    int m = nl + ((nr - nl) >> 1);
    /* 向下传递标记 */
    if(b[nf] && nl != nr)
    {
        // 如果当前节点被打上了标记,则更新该节点下两个子节点的值并将标记传递给子节点
        d[nf << 1] += b[nf] * (m - nl + 1), d[(nf << 1) | 1] += b[nf] * (nr - m);
        b[nf << 1] += b[nf], b[(nf << 1) | 1] += b[nf]; // 将标记下传给子节点
        b[nf] = 0; // 清空父节点标记
    } 
    /* 递归更新子区间 */
    if(tl <= m) update(tl, tr, dt, nl, m, nf << 1);
    if(tr > m) update(tl, tr, dt, m, nr, (nf << 1) | 1);
    d[nf] = d[nf << 1] + d[(nf << 1) | 1];
}

先这样吧 qwq

评论