树上背包NTT优化

发布时间:2022-06-28 发布网站:脚本宝典
脚本宝典收集整理的这篇文章主要介绍了树上背包NTT优化脚本宝典觉得挺不错的,现在分享给大家,也给大家做个参考。

主要结合两道例题讲,复杂度的计算很重要。

LOJ6290 花朵

非常容易可以考虑到树上背包的做法,但是过不了。

怎么将这个 (text{dp}) 优化呢?考虑背包实际上就是一个卷积的形式,所以我们可以用多项式科技优化卷积过程。

可以想到的,我们不能将背包直接卷积,因为复杂度由 (O(nm)) 变成 (O((n+m)log(n+m))) ,在 (n=1) 的时候复杂度反而会变慢。

具体的,我们将原树轻重链剖分,对于一个点,我们先通过类曼哈顿的贪心,每一次将最小的两个轻子树合并,等到一条重链上的点的所有节点的轻子树都合并完成了,我们再在这条重链上做分治 (text{fft})

复杂度是 (O(nlog^2n)) ,具体证明的话应该是考虑重链和轻链分开算。

对于重链来说,考虑单个长度为 (m) 的重链做分治 (text{fft}) 的复杂度是 (O(mlog mlog n)) 的,所以总复杂度是 (O(nlog^2n)) 的。

对于轻链来说,考虑每一个点的贡献,由于每一个点在一次合并操作中的均摊复杂度可以近似看成 (O(log n)) ,所以我们只需要计算出每一个点的操作次数即可。考虑一个点在进行轻子树合并的时候大小一定翻倍,所以其从自己位置一直合并到根的合并次数是一定小于 (log n) 的,所以这里的总复杂度也是 (O(nlog^2n)) 的。


卧槽发现有个直接套距阵的老哥,也太帅了吧。

我顿悟了,距阵掌握度 +1 。


关于矩阵实现的细节很重要。对于矩阵的优化,一定需要写出对应的状态转移方程,然后根据转移前的矩阵和转移后的矩阵写出转移矩阵。

#include<bits/stdc++.h>
using namespace std;
const int N=131072;
const int MOD=998244353,G=3;
int ADD(int x,int y){return x+y>=MOD?x+y-MOD:x+y;}
int TIME(int x,int y){return (int)(1ll*x*y%MOD);}
int ksm(int x,int k=MOD-2){int res=1;for(;k;k>>=1,x=TIME(x,x))if(k&1)res=TIME(res,x);return res;}
int rev[N],lst=0;
void get_rev(int lg){
	if(lst==lg) return ;else lst=lg;
	for(int i=0;i<(1<<lg);++i)
		rev[i]=((rev[i>>1]>>1)|((i&1)<<(lg-1)));
}
struct Polynomial{
	vector<int> f;
	int &operator [] (int x){return assert(x<(int)f.size()),f[x];}
	int len(){return (int)f.size();}void clear(){return f.clear();}
	void resize(int n){
		while((int)f.size()>n) f.pop_back();
		while((int)f.size()<n) f.push_back(0);
	}
	void NTT(int lg,bool tag){
		int n=(1<<lg);get_rev(lg),resize(n);
		for(int i=0;i<n;++i) if(i<rev[i]) swap(f[i],f[rev[i]]);
		for(int len=2;len<=n;len<<=1){
			int m=(len>>1),g=ksm(G,(MOD-1)/len);if(tag) g=ksm(g);
			for(int i=0;i<n;i+=len){
				for(int j=0,gg=1;j<m;++j,gg=TIME(gg,g)){
					int tmp=TIME(f[i+j+m],gg);
					f[i+j+m]=ADD(f[i+j],MOD-tmp),f[i+j]=ADD(f[i+j],tmp);
				}
			}
		}
		if(tag) for(int i=0,tmp=ksm(n);i<n;++i) f[i]=TIME(f[i],tmp);
	}
	void print(){
		for(int i=0;i<len();++i) printf("%d ",f[i]);
		printf("n");
	}
};
Polynomial init(int x){
	Polynomial res;return res.resize(2),res[1]=x,res;
}
Polynomial operator * (Polynomial f,Polynomial g){
	if(!f.len()||!g.len()) return Polynomial();
	int n=f.len()+g.len()-1,lg=0;while((1<<lg)<n) lg++;
	f.NTT(lg,false),g.NTT(lg,false);
	for(int i=0;i<(1<<lg);++i) f[i]=TIME(f[i],g[i]);
	return f.NTT(lg,true),f.resize(n),f;
}
Polynomial operator + (Polynomial f,Polynomial g){
	if(f.len()<g.len()) swap(f,g);
	for(int i=0;i<g.len();++i) f[i]=ADD(f[i],g[i]);
	return f;
}
struct Matrix{
	Polynomial f[2][2];
	void clear(){
		for(int i=0;i<2;++i){
			for(int j=0;j<2;++j)
				f[i][j].clear();
		}
	}
	void print(){
		for(int i=0;i<2;++i){
			for(int j=0;j<2;++j)
			printf("f[%d][%d]=",i,j),f[i][j].print();
		}
	}
};
bool operator < (Matrix a,Matrix b){
	int tmp1=max({a.f[0][0].len(),a.f[0][1].len(),a.f[1][0].len(),a.f[1][1].len()});
	int tmp2=max({b.f[0][0].len(),b.f[0][1].len(),b.f[1][0].len(),b.f[1][1].len()});
	return tmp1<tmp2;
}
bool operator > (Matrix a,Matrix b){
	int tmp1=max({a.f[0][0].len(),a.f[0][1].len(),a.f[1][0].len(),a.f[1][1].len()});
	int tmp2=max({b.f[0][0].len(),b.f[0][1].len(),b.f[1][0].len(),b.f[1][1].len()});
	return tmp1>tmp2;
}
Matrix operator * (Matrix a,Matrix b){
	Matrix res;res.clear();
	for(int i=0;i<2;++i){
		for(int k=0;k<2;++k){
			for(int j=0;j<2;++j)
			res.f[i][j]=res.f[i][j]+a.f[i][k]*b.f[k][j];
		}
	}
	return res;
}
int n,m,p[N];
struct Edge{int nxt,to;}e[N<<1];int fir[N];
void add(int u,int v,int i){e[i]=(Edge){fir[u],v},fir[u]=i;}
struct Node{int fa,son,siz;}tr[N];
void dfs1(int u){
	tr[u].siz=1;
	for(int i=fir[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==tr[u].fa) continue;
		tr[v].fa=u,dfs1(v),tr[u].siz+=tr[v].siz;
		if(tr[v].siz>tr[tr[u].son].siz) tr[u].son=v;
	}
}
Matrix cdq(vector<Matrix> &bag,int l,int r){
	if(l==r) return bag[l];
	int mid=(l+r)>>1;
	return cdq(bag,l,mid)*cdq(bag,mid+1,r);
}
priority_queue<Matrix,vector<Matrix>,greater<Matrix> > q;
Matrix merge(vector<Matrix> &bag){
	while(!q.empty()) q.pop();
	for(int i=0;i<(int)bag.size();++i) q.push(bag[i]);
	while(q.size()>1){
		Matrix a=q.top();q.pop();
		Matrix b=q.top();q.pop();
		q.push(a*b);
	}
	return q.top();
}
Matrix dfs2(int u){
	vector<Matrix> bag;
	for(;u;u=tr[u].son){
		vector<Matrix> BAG;
		for(int i=fir[u];i;i=e[i].nxt){
			int v=e[i].to;if(v==tr[u].fa||v==tr[u].son) continue;
			Matrix tmp=dfs2(v),TMP;TMP.clear();
			TMP.f[1][1]=tmp.f[0][0];
			TMP.f[0][0]=TMP.f[1][1]+tmp.f[1][0];
			BAG.push_back(TMP);
		}
		Matrix tmp,TMP;TMP.clear();
		if(BAG.empty()){
			tmp.clear();
			tmp.f[0][0].resize(1),tmp.f[0][0][0]=1;
			tmp.f[1][1].resize(1),tmp.f[1][1][0]=1;
		}
		else tmp=merge(BAG);
		TMP.f[0][0]=TMP.f[0][1]=tmp.f[0][0];
		TMP.f[1][0]=tmp.f[1][1]*init(p[u]);
		bag.push_back(TMP);
	}
	return cdq(bag,0,(int)bag.size()-1);
}
int main(){
	cin>>n>>m;
	for(int i=1;i<=n;++i) scanf("%d",&p[i]);
	for(int i=1;i<n;++i){
		int u,v;scanf("%d%d",&u,&v);
		add(u,v,i<<1),add(v,u,i<<1|1);
	}
	dfs1(1);Matrix tmp=dfs2(1);Polynomial res;
	res=tmp.f[0][0]+tmp.f[1][0],res.resize(m+1);
	return printf("%dn",res[m]),0;
}

GYM102331J Jiry Matchings

我们考虑这里的合并的复杂度是 (O(n+m)) 的,也是需要轻重链剖分的。

具体操作和上面一样,分析出来的复杂度好像是 (O(nlog n)) 的,嘿嘿嘿比孔老爷分析得快,只要我分析的比他快我就比他跑得快

脚本宝典总结

以上是脚本宝典为你收集整理的树上背包NTT优化全部内容,希望文章能够帮你解决树上背包NTT优化所遇到的问题。

如果觉得脚本宝典网站内容还不错,欢迎将脚本宝典推荐好友。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
如您有任何意见或建议可联系处理。小编QQ:384754419,请注明来意。
标签: