题目链接
前言
Captain Orange 讲课的作业题,同时也是第一次长时间没打某种大数据结构后,第一次不看以前的 讲义/笔记/博客 YY出来了,写文祭之
题目大意
给你一个长度为 $n$ 的数列,当数列中的一个数对 $(i,j)\quad i<=j$ 满足 $a_ia_j<=max(a_i,a_{i+1},a_{i+1},\cdots,a_j)$ ,则称这个数对是美丽的,求出美丽的数对的个数$\qquad 1\le n\le1e5\quad 1\le a_i\le1e9$
题解
题目中要求的是一段区间,所以我们考虑分治(套路)。观察题目,我们发现对于每一段区间,对于答案有影响的只有最大的数,我们可以用 st表 提前预处理出当前区间最大的数的值和位置,设这个位置为 $mid$
得出位置 $mid$ 之后,我们将当前区间分为两部分 $[l,mid]$ 和 $[mid,r]$ ,枚举每一个数对,此时复杂度约为 $O(n^2)$ 。考虑对当前解法进行优化,我们发现只要枚举了其中一个端点之后,我们可以直接计算出另外一个端点的最大值,即为 $\frac{a_{mid}}{a_i}$。
容易发现,在另外一个区间中,只要值小于 $\frac{a_{mid}}{a_i}$ 都对答案有贡献,而区间kth,我们可以离散化后用主席树轻松维护
代码
void build(int &u,int l,int r)
{
if(!u)u=++tot;
if(l==r)return;
int mid=(l+r)>>1;
build(tr[u].ls,l,mid);
build(tr[u].rs,mid+1,r);
}
void insert(int &u,int pre,int l,int r,int x)
{
if(!u)u=++tot;
tr[u]=tr[pre];
++tr[u].si;
if(l==r)return;
int mid=(l+r)>>1;
if(x<=mid)
insert(tr[u].ls=0,tr[pre].ls,l,mid,x);
else
insert(tr[u].rs=0,tr[pre].rs,mid+1,r,x);
}
int query(int t1,int t2,int l,int r,int x,int y)
{
if(y<x)return 0;
if(l>=x&&r<=y)return tr[t2].si-tr[t1].si;
int mid=(l+r)>>1;
int res=0;
if(mid>=x)res+=query(tr[t1].ls,tr[t2].ls,l,mid,x,y);
if(mid<y)res+=query(tr[t1].rs,tr[t2].rs,mid+1,r,x,y);
return res;
}
inline void dfs(int l,int r)
{
if(l>r)return;
int as=lg[r-l+1];
noder t=get_max(st[as][l],st[as][r-(1<<as)+1]);
if(r-t.pos+1>t.pos-l+1)
for(register int i=l;i<=t.pos;++i)
{
int val=lower_bound(b+1,b+1+m,t.val/a[i])-b;
if(b[val]!=t.val/a[i])--val;
ans+=query(rt[t.pos-1],rt[r],1,m,1,val);
}
else
for(register int i=t.pos;i<=r;++i)
{
int val=lower_bound(b+1,b+1+m,t.val/a[i])-b;
if(b[val]!=t.val/a[i])--val;
ans+=query(rt[l-1],rt[t.pos],1,m,1,val);
}
dfs(l,t.pos-1);
dfs(t.pos+1,r);
}
signed main(void)
{
n=read();
for(register int i=1;i<=n;++i)a[i]=b[i]=read(),st[0][i].val=a[i],st[0][i].pos=i;
sort(b+1,b+1+n);
m=unique(b,b+1+n)-b-1;
build(rt[0],1,m);
lg[0]=-1;
for(register int i=1;i<=n;++i)
{
int t=lower_bound(b+1,b+1+m,a[i])-b;
insert(rt[i],rt[i-1],1,m,t),lg[i]=lg[i/2]+1;
}
for(register int j=1;j<=20;++j)
for(register int i=1;i+(1<<j)<=n+1;++i)
st[j][i]=get_max(st[j-1][i],st[j-1][i+(1<<(j-1))]);
dfs(1,n);
printf("%lld\n",ans);
}