题目

代码
#include <bits/stdc++.h>
using namespace std;
// Capacity of every vertex-indexed array (100000 vertices plus slack).
const int N = 100000 + 10;
// Capacity of the edge arrays; each undirected edge is stored as two
// directed entries.
const int M = 200000 + 10;

// Values read from the first input line: vertex count and query count.
int n;
int q;
// cnt[v] is incremented once for every input edge touching v (its degree).
int cnt[N];

// Forward-linked adjacency lists maintained by add():
//   h[v]   - index of the first edge entry of v (-1 once main() has reset it)
//   e[i]   - target vertex of edge entry i
//   ne[i]  - index of the next edge entry in the same list
//   idx    - next unused edge entry
int h[N];
int e[M];
int ne[M];
int idx;

// Binary-lifting tables filled by dfs():
//   d[v]        - depth of v
//   f[v][k]     - vertex reached by going 2^k parent steps up from v
//   dist[v][k]  - sum of cnt[] over the vertices visited by those 2^k steps
int d[N];
int f[N][18];
int dist[N][18];
// Records the directed edge a -> b by pushing a new entry onto the front of
// a's adjacency list. Undirected edges are stored by calling this twice.
void add(int a, int b)
{
    const int slot = idx;   // entry claimed for this edge
    ++idx;
    e[slot] = b;            // where the edge leads
    ne[slot] = h[a];        // previous list head becomes the successor
    h[a] = slot;            // new entry is now the list head
}
/**
 * Fills the depth and binary-lifting tables for every vertex below u.
 *
 * For each vertex j reached from its parent p the function writes
 *   d[j]       = d[p] + 1
 *   f[j][k]    = the vertex 2^k parent steps above j
 *   dist[j][k] = sum of cnt[] over the vertices visited by those 2^k steps
 *                (j itself is not included).
 * Index 0 serves as the "no ancestor" slot: the tables are zero-initialised
 * globals, so lifting past the start vertex lands on 0 and adds nothing.
 *
 * The traversal keeps its own stack instead of recursing, so a very deep
 * tree (for example a path of ~1e5 vertices) cannot exhaust the call stack.
 * A parent's rows are always complete before any of its children are
 * processed, which is all the lifting recurrence needs.
 *
 * @param u   start vertex; its own d/f/dist entries are left untouched
 * @param fa  neighbour of u that must not be entered (main() passes -1 for
 *            the root, which matches no vertex)
 */
void dfs(int u, int fa)
{
    std::vector<int> pending;
    pending.push_back(u);

    while (!pending.empty())
    {
        const int cur = pending.back();
        pending.pop_back();

        // The start vertex takes its parent from the argument; every other
        // vertex had f[cur][0] written just before it was pushed.
        const int parent = (cur == u) ? fa : f[cur][0];

        for (int i = h[cur]; ~i; i = ne[i])
        {
            const int j = e[i];
            if (j == parent) continue;   // do not walk back up the tree

            d[j] = d[cur] + 1;
            f[j][0] = cur;
            dist[j][0] = cnt[cur];
            for (int k = 1; k <= 17; k++)
            {
                // A 2^k jump is two consecutive 2^(k-1) jumps.
                const int mid = f[j][k-1];
                f[j][k] = f[mid][k-1];
                dist[j][k] = dist[j][k-1] + dist[mid][k-1];
            }

            pending.push_back(j);
        }
    }
}
/**
 * Returns the sum of cnt[] over all vertices on the tree path between a and
 * b, counting every vertex (both endpoints and the meeting vertex) once.
 *
 * Despite its name the function returns this path sum, not the common
 * ancestor itself. It relies on the d/f/dist tables prepared by dfs().
 */
int lca(int a, int b)
{
    // Both endpoints belong to the path; dist[][] only covers vertices
    // strictly above the one being lifted.
    int total = cnt[a] + cnt[b];

    // Keep a as the deeper endpoint.
    if (d[b] > d[a]) std::swap(a, b);

    // Raise a to the depth of b, collecting the vertices passed on the way.
    for (int k = 17; k >= 0; --k)
    {
        const int up = f[a][k];
        if (d[up] < d[b]) continue;
        total += dist[a][k];
        a = up;
    }

    // b was an ancestor of a (or the same vertex): it was added twice.
    if (a == b) return total - cnt[a];

    // Raise both endpoints in lock-step while they remain below the
    // meeting vertex.
    for (int k = 17; k >= 0; --k)
    {
        const int upA = f[a][k];
        const int upB = f[b][k];
        if (upA == upB) continue;
        total += dist[a][k];
        total += dist[b][k];
        a = upA;
        b = upB;
    }

    // One final step from each side reaches the meeting vertex, which is
    // therefore added twice; remove the duplicate.
    const int meet = f[a][0];
    total += dist[a][0] + dist[b][0];
    return total - cnt[meet];
}
int main()
{
memset(h, -1, sizeof h);
scanf("%d%d", &n, &q);
for(int i = 1; i < n; i++)
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b); add(b, a);
cnt[a]++, cnt[b]++;
}
d[1] = 1;
dfs(1, -1);
while(q--)
{
int a, b;
scanf("%d%d", &a, &b);
printf("%d\n", lca(a, b));
}
}