[UOJ182]a^-1 + b problem
Posted jefflyy
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[UOJ182]a^-1 + b problem相关的知识,希望对你有一定的参考价值。
$\newcommand{\align}[1]{\begin{align*}#1\end{align*}}$做这题需要一个前置知识:多项式的多点求值
多项式的多点求值:给定多项式$f(x)$和$x_{1\cdots n}$,要求出$f(1)\cdots f(n)$
首先,我们可以找到$g_i(x)$使得$f(x)=(x-x_i)g_i(x)+C$(就是把$f(x)$对$x-x_i$取模),当$x=x_i$,我们得到$f(x_i)=C$,即$f(x_i)=\left.f(x)\%(x-x_i)\right|_{x=x_i}$,所以我们要求的是$f(x)\%(x-x_i)$,直接对$n$个$x_i$暴力求是$O(n^2\log_2n)$的,比暴力还慢,但一个很显然的事实是:如果$g(x)=h(x)r(x)$,那么$f(x)\%g(x)\%h(x)=f(x)\%h(x)$,所以我们这样分治求解:如果要求出$f(x)$在$x_{l\cdots r}$的取值,那么就递归计算$\align{f(x)\%\prod\limits_{i=l}^r(x-x_i)}$在$x_{l\cdots mid}$和$x_{mid+1\cdots r}$的取值,因为有取模,所以$f(x)$的次数被降了下来,总时间复杂度$T(n)=2T\left(\dfrac n2\right)+O(n\log_2n)=O(n\log_2^2n)$,注意要用分治FFT预处理出$\align{\prod\limits_{i=l}^r(x-x_i)}$,时间复杂度也是$O(n\log_2^2n)$
然后是这道题,因为是全局操作,所以我们定义$f_i(x)$表示经过$i$次操作后,原来的$x$会变成$f_i(x)$,每次操作要么是将$f(x)$加上一个常数,要么是把它取倒数,所以它的形式肯定是$f(x)=\dfrac{ax+b}{cx+d}=p+\dfrac q{x+t}$($c=0$要特殊处理)
所以我们要求的答案是$\align{\sum\limits_{i=1}^nf(x_i)}$,展开得到$\align{pn+q\sum\limits_{i=1}^n\dfrac1{x_i+t}}$,在这个式子中,$x_i$是常数,而$t$随着修改变化($m$个取值),所以我们把它看成关于$t$的函数$\align{g(t)=\sum\limits_{i=1}^n\dfrac1{x_i+t}}=\dfrac{\sum\limits_{i=1}^n\prod\limits_{j\ne i}(x_j+t)}{\prod\limits_{i=1}^n(x_i+t)}$,分母可以分治FFT算,分子是分母的导数,算出来后直接多点求值就做完了...
注意:凡是涉及分治FFT,需要new内存的,一定要注意不能访问超限,这时assert就派上用场了>_<
#include<stdio.h> #include<string.h> #include<assert.h> const int mod=998244353,maxn=262144; typedef long long ll; int mul(int a,int b){return a*(ll)b%mod;} int ad(int a,int b){return(a+b)%mod;} int de(int a,int b){return(a-b)%mod;} void swap(int&a,int&b){ int c=a; a=b; b=c; } int max(int a,int b){return a>b?a:b;} int pow(int a,int b){ int s=1; while(b){ if(b&1)s=mul(s,a); a=mul(a,a); b>>=1; } return s; } int rev[maxn],N,iN; void pre(int n){ int i,k; for(N=1,k=0;N<n;N<<=1)k++; for(i=0;i<N;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1)); iN=pow(N,mod-2); } void ntt(int*a,int on){ int i,j,k,t,w,wn; for(i=0;i<N;i++){ if(i<rev[i])swap(a[i],a[rev[i]]); } for(i=2;i<=N;i<<=1){ wn=pow(3,(on==1)?(mod-1)/i:(mod-1-(mod-1)/i)); for(j=0;j<N;j+=i){ w=1; for(k=0;k<i>>1;k++){ t=mul(w,a[i/2+j+k]); a[i/2+j+k]=de(a[j+k],t); a[j+k]=ad(a[j+k],t); w=mul(w,wn); } } } if(on==-1){ for(i=0;i<N;i++)a[i]=mul(a[i],iN); } } int t0[maxn]; void getinv(int*a,int*b,int n){ if(n==1){ b[0]=pow(a[0],mod-2); return; } int i; getinv(a,b,n>>1); pre(n<<1); memset(t0,0,N<<2); memcpy(t0,a,n<<2); ntt(t0,1); ntt(b,1); for(i=0;i<N;i++)b[i]=mul(b[i],2-mul(b[i],t0[i])); ntt(b,-1); for(i=n;i<N;i++)b[i]=0; } int ta[maxn],tb[maxn],tc[maxn]; void add(int*a,int n,int*b,int m,int*c,int&k){ k=max(n,m); for(int i=0;i<=k;i++)tc[i]=ad(a[i],b[i]); while(k!=0&&tc[k]==0)k--; memcpy(c,tc,(k+1)<<2); } void dec(int*a,int n,int*b,int m,int*c,int&k){ k=max(n,m); for(int i=0;i<=k;i++)tc[i]=de(a[i],b[i]); while(k!=0&&tc[k]==0)k--; memcpy(c,tc,(k+1)<<2); } void reverse(int*a,int n){ for(int i=0;i<=n>>1;i++)swap(a[i],a[n-i]); } void mul(int*a,int n,int*b,int m,int*c,int&k){ int i; k=n+m; pre(k+1); memset(ta,0,N<<2); memset(tb,0,N<<2); memcpy(ta,a,(n+1)<<2); memcpy(tb,b,(m+1)<<2); ntt(ta,1); ntt(tb,1); for(i=0;i<N;i++)tc[i]=mul(ta[i],tb[i]); ntt(tc,-1); memcpy(c,tc,(k+1)<<2); } int t1[maxn]; void div(int*a,int n,int*b,int m,int*c,int&k){ if(n<m){ k=0; return; } int i,rn; for(rn=1;rn<n-m+1;rn<<=1); memset(ta,0,rn<<3); memset(tb,0,rn<<3); memcpy(ta,a,(n+1)<<2); memcpy(tb,b,(m+1)<<2); reverse(tb,m); for(i=rn;i<=m;i++)tb[i]=0; memset(t1,0,rn<<3); getinv(tb,t1,rn); pre(rn<<1); reverse(ta,n); for(i=rn;i<=n;i++)ta[i]=0; ntt(ta,1); ntt(t1,1); for(i=0;i<N;i++)tc[i]=mul(ta[i],t1[i]); ntt(tc,-1); k=n-m; reverse(tc,k); while(k!=0&&tc[k]==0)k--; memcpy(c,tc,(k+1)<<2); } int len; void modulo(int*a,int n,int*b,int m,int*c,int&k){ if(n<m){ k=n; memcpy(c,a,(n+1)<<2); return; } div(a,n,b,m,t1,k); mul(t1,k,b,m,t1,k); //assert(max(n,k)<=len); dec(a,n,t1,k,c,k); } struct frac{//(ax+b)/(cx+d) int a,b,c,d; void add(int k){ a=ad(a,mul(c,k)); b=ad(b,mul(d,k)); } void inv(){ swap(a,c); swap(b,d); } }fr[60010]; int x[100010],op[60010],v[60010],ti[60010],*tr[240010],M; void build(int l,int r,int x){ if(l==r){ tr[x]=new int[2]; tr[x][0]=-ti[l]; tr[x][1]=1; return; } int mid=(l+r)>>1; build(l,mid,x<<1); build(mid+1,r,x<<1|1); tr[x]=new int[r-l+2]; mul(tr[x<<1],mid-l+1,tr[x<<1|1],r-mid,tr[x],x); } void solve(int*f,int n,int l,int r,int x,int*ans){ int mid=(l+r)>>1,*now; now=new int[r-l+1]; len=r-l; modulo(f,n,tr[x],r-l+1,now,n); if(l==r){ ans[l]=now[0]; return; } solve(now,n,l,mid,x<<1,ans); solve(now,n,mid+1,r,x<<1|1,ans); } int t2[maxn],t3[maxn]; int*solve2(int l,int r){ int mid,*res,*L,*R,len; res=new int[r-l+2]; if(l==r){ res[1]=1; res[0]=x[l]; return res; } mid=(l+r)>>1; L=solve2(l,mid); R=solve2(mid+1,r); mul(L,mid-l+1,R,r-mid,res,len); return res; } int ans1[60010],ans2[60010],ans[60010],up[100010]; int main(){ int n,m,i,p,q,del,*res; scanf("%d%d",&n,&m); for(i=1;i<=n;i++)scanf("%d",x+i); fr->a=fr->d=1; fr->b=fr->c=0; for(i=1;i<=m;i++){ fr[i]=fr[i-1]; scanf("%d",op+i); if(op[i]==1){ scanf("%d",v+i); fr[i].add(v[i]); }else fr[i].inv(); if(op[i]==2){ M++; ti[M]=mul(fr[i].d,pow(fr[i].c,mod-2)); } } del=0; for(i=1;i<=n;i++)ans[0]=ad(ans[0],x[i]); if(M==0){ for(i=1;i<=m;i++){ del=ad(del,v[i]); printf("%d\n",ad(ans[0],mul(del,n))); } return 0; } build(1,M,1); res=solve2(1,n); for(i=1;i<=n;i++)up[i-1]=mul(res[i],i); solve(up,n-1,1,M,1,ans1); solve(res,n,1,M,1,ans2); M=del=0; for(i=1;i<=m;i++){ if(op[i]==1){ del=ad(del,v[i]); printf("%d\n",ad(ad(ans[M],mul(n,del)),mod)); }else{ M++; if(fr[i].c==0){ printf("%d\n",ans[M]=ad(mul(ad(mul(fr[i].a,ans[0]),mul(fr[i].b,n)),pow(fr[i].d,mod-2)),mod)); continue; } p=mul(fr[i].a,pow(fr[i].c,mod-2)); q=mul(de(mul(fr[i].b,fr[i].c),mul(fr[i].a,fr[i].d)),pow(mul(fr[i].c,fr[i].c),mod-2)); ans[M]=ad(mul(p,n),mul(q,mul(ans1[M],pow(ans2[M],mod-2)))); printf("%d\n",ad(ans[M],mod)); del=0; } } }
以上是关于[UOJ182]a^-1 + b problem的主要内容,如果未能解决你的问题,请参考以下文章