51Nod1022 石子归并(环形)[区间DP,四边形不等式优化]

题目链接:https://www.51nod.com/onlineJudge/questionCode.html#!problemId=1022

线性石子合并,改成了环形并且n达到了1000。环形比较好解决,把n堆石子复制一遍放到原来的石子后边,长度仍然枚举到n,区间右边界枚举到2*n,答案就是所有长度为n的区间的解的最大值。
麻烦的是数据范围比较大,O(n^3)的算法无法解决,这里需要用到四边形不等式优化。
优化本身的数学证明比较麻烦,直接记条件和结论了。

四边形不等式优化条件

在动态规划中,经常遇到形如下式的转台转移方程:
m(i,j)=min{m(i,k-1),m(k,j)}+w(i,j)(i≤k≤j)(min也可以改为max)
上述的m(i,j)表示区间[i,j]上的某个最优值。w(i,j)表示在转移时需要额外付出的代价。该方程的时间复杂度为O(N^3)。

下面我们通过四边形不等式来优化上述方程,首先介绍什么是”区间包含的单调性“和”四边形不等式“
(1)区间包含的单调性:如果对于i≤i'<j≤j',有w(i',j)≤w(i,j'),那么说明w具有区间包含的单调性。(可以形象理解为如果小区间包含于大区间中,那么小区间的w值不超过大区间的w值)
(2)四边形不等式:如果对于i≤i'<j≤j',有w(i,j)+w(i',j')≤w(i',j)+w(i,j'),我们称函数w满足四边形不等式。(可以形象理解为两个交错区间的w的和不超过小区间与大区间的w的和)

下面给出两个定理

定理一:如果上述的w函数同时满足区间包含单调性和四边形不等式性质,那么函数m也满足四边形不等式性质。
我们再定义s(i,j)表示m(i,j)取得最优值时对应的下标(即i≤k≤j时,k处的w值最大,则s(i,j)=k)。此时有如下定理
定理二:假如m(i,j)满足四边形不等式,那么s(i,j)单调,即s(i,j)≤s(i,j+1)≤s(i+1,j+1)。

好了,有了上述的两个定理后,我们发现如果w函数满足区间包含单调性和四边形不等式性质,那么有s(i,j-1)≤s(i,j)≤s(i+1,j)。即原来的状态转移方程可以改写为下式:
m(i,j)=min{m(i,k-1),m(k,j)}+w(i,j)(s(i,j-1)≤k≤s(i+1,j))(min也可以改为max) 注:具体代码实现中k取满足条件的最大值或最小值,下文代码取的最大值。

由于这个状态转移方程枚举的是区间长度L=j-i,而s(i,j-1)和s(i+1,j)的长度为L-1,是之间已经计算过的,可以直接调用。不仅如此,区间的长度最多有n个,对于固定的长度L,不同的状态也有n个,故时间复杂度为O(N^2),而原来的时间复杂度为O(N^3),实现了优化!今后只需要根据方程的形式以及w函数是否满足两条性质即可考虑使用四边形不等式来优化了。

引用自XDU_Skyline的博客

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=2009;
const int inf=0x3f3f3f3f;
int dp[maxn][maxn];
int sum[maxn];
int s[maxn][maxn];
int main() {
int n;
scanf("%d",&n);
int x;
for(int i=1;i<=n;i++) {
scanf("%d",&x);
sum[i]=sum[i-1]+x;
}
for(int i=n+1;i<=2*n;i++) {
sum[i]=sum[i-1]-sum[i-n-1]+sum[i-n];
}
for(int i=0;i<maxn;i++)
for(int j=0;j<maxn;j++)
dp[i][j] = i==j ? 0 : inf;
for(int i=1;i<2*n;i++) {
dp[i][i+1]=sum[i+1]-sum[i-1];
s[i][i+1]=i+1;
}
for(int len=3;len<=n;len++) {
for(int l=1;l+len-1<=2*n;l++) {
int r=l+len-1;
for(int m=s[l][r-1];m<=s[l+1][r];m++) {
int tmp=dp[l][m-1]+dp[m][r]+sum[r]-sum[l-1];
if(tmp<=dp[l][r]) { //保证s中保存的m是满足dp(i,j)=min{dp(i,m-1),dp(m,j)}+dp(i,j)(s(i,j-1)≤m≤s(i+1,j))的最大值
dp[l][r]=tmp;
s[l][r]=m;
}
}
}
}
int ans=inf;
for(int i=1;i+n-1<=2*n;i++) {
ans=min(ans,dp[i][i+n-1]);
}
printf("%d\n",ans);
}