SPOJ-Count on a tree(COT)【主席树+LCA】

求树上路径第k大。

主席树。从根往孩子建。查询(x,y)路径时,在第x棵主席树+第y棵主席树-两倍的第lca(x,y)棵主席树上查询即可。还要特判一下lca(x,y)这个节点。

我写的速度慢出翔。QAQ

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=200010,maxm=500010,maxt=4000100;
int p[maxm],n1[maxm],h[maxn],T[maxn],lc[maxt],rc[maxt],s[maxt],a[maxn],bin[maxn];
int val[maxn],anc[maxn][20],dep[maxn],LcaVal,tot=0,ee=0,n,m,q[maxn],vis[maxn]={0};
char buf[8000000],*pt=buf,*o=buf;
inline int getint(){
    int s=0; while(*pt<'0'||*pt>'9')pt++;
    while(*pt>='0'&&*pt<='9')s=s*10+*pt++-48; return s;
}inline void print(int x){
    char str[12],*p=str; if(!x)*o++=48;
    else{ while(x) *p++=x%10+48,x/=10; while(p--!=str)*o++=*p;};
}inline void ae(int a,int b){
	p[ee]=b;	n1[ee]=h[a];	h[a]=ee++;
	p[ee]=a;	n1[ee]=h[b];	h[b]=ee++;
}void build(int l,int r,int &x){
	s[x=++tot]=0;
	if(l==r)	return;
	int mid=l+r>>1;
	build(l,mid,lc[x]);
	build(mid+1,r,rc[x]);
}void modify(int p,int v,int l,int r,int &x){
	s[x=++tot]=s[p]+1;
	lc[x]=lc[p];
	rc[x]=rc[p];
	if(l==r)	return;
	int mid=l+r>>1;
	if(v<=mid)	modify(lc[p],v,l,mid,lc[x]);
	else		modify(rc[p],v,mid+1,r,rc[x]);
}int ask(int a,int b,int c,int l,int r,int k){
	if(l==r)	return l;
	int cnt=s[lc[a]]+s[lc[b]]-(s[lc[c]]<<1),mid=l+r>>1;
	if(LcaVal>=l && LcaVal<=(l+r>>1))	cnt++;
	if(k<=cnt)	return ask(lc[a],lc[b],lc[c],l,mid,k);
	else		return ask(rc[a],rc[b],rc[c],mid+1,r,k-cnt);
}int lca(int u,int v){
	if(dep[u]<dep[v])	swap(u,v);
	int k=dep[u]-dep[v],j=0;
	while(k){
		if(k&1)	u=anc[u][j];
		k>>=1,j++;
	}if(u==v)	return u;
	for(int i=16;~i;i--)
		if(anc[u][i]!=anc[v][i])	u=anc[u][i],v=anc[v][i];
	return anc[u][0];
}int main(){
	fread(buf,1,8000000,stdin);
	n=getint();	m=getint();
	for(int i=1;i<=n;i++)	bin[i]=a[i]=getint();
	memset(h,-1,sizeof(h));
	int x,y;
	for(int i=1;i<n;i++)	ae(getint(),getint());
	sort(bin+1,bin+n+1);
	for(int i=1;i<=n;i++)	a[i]=lower_bound(bin+1,bin+n+1,a[i])-bin;
	build(1,n,T[0]);
	int head=0,tail=1;
	q[0]=vis[1]=dep[1]=1;
	modify(T[0],a[1],1,n,T[1]);
	while(head<tail){
		int u=q[head++];
		for(int i=h[u];~i;i=n1[i]){
			if(!vis[p[i]]){
				modify(T[u],a[p[i]],1,n,T[p[i]]);
				vis[p[i]]=1;
				dep[p[i]]=dep[u]+1;
				anc[p[i]][0]=u;
				q[tail++]=p[i];
			}
		}
	}for(int j=1;j<=16;j++)	
		for(int i=1;i<=n;i++)
			anc[i][j]=anc[anc[i][j-1]][j-1];
	int la=0,z,k;
	while(m--){
		x=getint();	y=getint();	k=getint();
		z=lca(x,y);
		LcaVal=a[z];
		if(x==y)	print(bin[a[x]]);
		else	print(bin[ask(T[x],T[y],T[z],1,n,k)]);
		if(m)	*o++='\n';
	}return fwrite(buf,1,o-buf,stdout),0;
}