这是我发的第一篇题解,当时没有加任何LaTeX\LaTeX,所有的LaTeX\LaTeX都是我补充上去的。

我对树形DP又有了更深一层的认识。。。。

方程还是很好想的

f[u][j]f[u][j]为当前u这个点的子树内,分给儿子们jj个点(也即自己留下siz[u]jsiz[u] - j个作为一个联通块)的乘积

于是有f[u][j]=maxf[u][j],f[v][k]+f[jk]f[u][j] = max{f[u][j],f[v][k] + f'[j - k]} (0 <= kk <= jj < siz[u]siz[u])    (ff'数组为之前的儿子所计算出的f[u]f[u]

特别的,对于f[u][siz[u]f[u][siz[u]f[u][siz[u]]f[u][siz[u]] = f[u][j](siz[u]j)f[u][j] * (siz[u] * j) (0 <= jj < siz[u]siz[u])

嗯。。这是蒟蒻的初步想法

飞速打完交了一遍以后。。。。发现他居然爆了long long。。。。。。。。。。

然后打高精。。。。。

接着T了无数遍。。。。。

然后看大爷的题解。。。。

发现大爷只是比我多了一个优化???????

于是引入了一个关于树形DP基本的复杂度的证明:

对于上述方程,我们修改一下,把j - k和k的位置对调一下,变成

f[u][j]=maxf[u][j],f[v][jk]+f[k]f[u][j] = max{f[u][j],f[v][j - k] + f'[k]} (0 <=  jj < siz[u]siz[u] , 0 <= kk <= min(j,siz[u]siz[v]))min(j , siz[u] - siz[v]))

显然他仍然是与原方程等价的,但是复杂度却完全不同

原方程的上限复杂度显然是O(n3)O(n ^ 3)的(不计高精度),

而新方程的实质是对于一个当前u大小为j的子树中,去找不属于v的一棵大小为k的子树

那么这时单次DP的复杂度为O(siz[v](siz[u]siz[v]))O(siz[v] * (siz[u] - siz[v]))

这个优化可能不是很明显,但是我们这样考虑:由于两个子树是不相交的,那么也就可以看做是两个子树内所有的点两两求一遍lca,并且贡献只在lca处计算一次的复杂度

于是总的复杂度为O(n2)O(n ^ 2)

以下为代码

#include<cstdio>
#include<cmath>
#include<queue>
#include<stack>
#include<vector>
#include<algorithm>
#include<cstring>
using namespace std;
 
typedef long long LL;
 
const int INF = 2147483647;
const int maxn = 710;
const int r = 1000000000;
 
struct data{
	int tot;
	LL m[20]; data(){memset(m,0,sizeof(m)); tot = 1;}
	data operator * (data b) const
	{
		data ret;
		for (int i = 1; i <= tot; i++)
		{
			for (int j = 1; j <= b.tot; j++)
			{
				ret.m[i + j] += (ret.m[i + j - 1] + m[i] * b.m[j]) / r;
				ret.m[i + j - 1] = (ret.m[i + j - 1] + m[i] * b.m[j]) % r;
			}
		}
		ret.tot = tot + b.tot - 1;
		while (ret.m[ret.tot + 1])
		{
			ret.tot++;
			ret.m[ret.tot + 1] += ret.m[ret.tot] / r;
			ret.m[ret.tot] = ret.m[ret.tot] % r;
		}
		return ret;
	}
	data operator & (LL b) const
	{
		data ret;
		for (int i = 1; i <= tot; i++)
		{
			ret.m[i + 1] += (ret.m[i + 1] + m[i] * b) / r;
			ret.m[i] = (ret.m[i] + m[i] * b) % r;
		}
		ret.tot = tot;
		while (ret.m[ret.tot + 1])
		{
			ret.tot++;
			ret.m[ret.tot + 1] += ret.m[ret.tot] / r;
			ret.m[ret.tot] = ret.m[ret.tot] % r;
		}
		return ret;
	}
};

vector<int> e[maxn];
LL n,siz[maxn],s[maxn];
data f[maxn][maxn],ff[maxn],g[maxn];
 
inline LL getint()
{
	LL ret = 0,f = 1;
	char c = getchar();
	while (c < '0' || c > '9')
	{
		if (c == '-') f = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9')
		ret = ret * 10 + c - '0',c = getchar();
	return ret * f;
}
 
inline data hmax(data a,data b)
{
	if (a.tot < b.tot) return b;
	if (b.tot < a.tot) return a;
	for (int i = a.tot; i >= 1; i--)
	{
		if (a.m[i] > b.m[i]) return a;
		if (a.m[i] < b.m[i]) return b;
	}
	return a;
}
 
inline void dp(int u,int fa)
{
	siz[u]++;
	for (int i = 0; i < e[u].size(); i++)
	{
		int v = e[u][i];
		if (v == fa) continue;
		dp(v,u);
		siz[u] += siz[v];
		for (int j = 0; j <= siz[u]; j++) ff[j] = f[u][j];
		for (int j = 0; j <= siz[u] - 1; j++)
			for (int k = 0; k <= min(1ll * j,siz[u] - siz[v]); k++)
				f[u][j] = hmax(f[u][j],f[v][j - k] * ff[k]);
	}
	for (LL j = 0; j <= siz[u] - 1; j++)
		f[u][siz[u]] = hmax(f[u][siz[u]],f[u][j] & (siz[u] - j));
}
 
inline void init()
{
	for (int i = 1; i <= n; i++) 
		f[i][0].m[1] = 1;
}
 
inline void print(data a)
{
	int cnt,pos;
	for (int i = 1; i <= a.tot; i++)
	{
		cnt = 0; pos = (i - 1) * 9;
		while (a.m[i])
		{
			s[++pos] = a.m[i] % 10;
			a.m[i] /= 10;
		}
	}
	for (int i = pos; i >= 1; i--) printf("%d",s[i]);
}
 
int main()
{
	n = getint();
	for (int i = 1; i <= n - 1; i++)
	{
		int u = getint(),v = getint();
		e[u].push_back(v); e[v].push_back(u);
	}
	init();
	dp(1,0);
	print(f[1][n]);
	return 0;
}//注释在开头文本中把难的都打了一遍,其他应该也看得懂