Message
Message
h>
using namespace std;
mt19937 rng(std::chrono::system_clock::now().time_since_epoch().count());

// Global scratch state shared by f() and sol1(). sol1() re-initialises everything
// it reads before using it, so it can be called many times (e.g. by a stress
// tester). All node indices and n must be < 2005.
long long n;                   // not read by sol1(): its parameter `n` shadows this
vector<long long> v[2005];     // undirected adjacency lists: edge i -- aa[i] for each assigned i
long long comp = 0,            // number of connected components found
          aa[2005],            // 0-based target of node i, or -2 when unassigned
          vis[2005],           // DFS visited flags
          sz[2005],            // sz[c]: number of nodes in component c
          cyc[2005],           // cyc[c]: 1 while component c has no cycle, else 0
          dp[2005][2005],      // dp[i][j]: see sol1()
          fact[2005],          // fact[i] = i! mod 998244353
          cy = 0,              // number of acyclic components
          res = 0,             // partial answer accumulated by sol1()
          t = 1;               // n^cy mod 998244353, computed by sol1()

// Recursive DFS from node `a` (reached from `p`, or -1 at the root). It adds every
// reachable node to component `x` and clears cyc[x] when it finds a cycle.
void f(long long a, long long x, long long p)
{
    vis[a] = 1;
    sz[x]++;
    for (long long to : v[a])
    {
        if (!vis[to])
        {
            f(to, x, a);
        }
        // A visited neighbour other than the parent closes a cycle. This covers
        // self-loops and, at the root (p == -1), any revisit. An edge back to the
        // parent closes a cycle only if a and p point at each other (a 2-cycle,
        // which shows up as a doubled edge). The test must use aa[a], the pointer
        // of the current node; indexing aa with the adjacency position wrongly
        // flagged plain chains such as 1->2->3 as cyclic.
        else if (to != p || (aa[a] == to && aa[to] == a))
        {
            cyc[x] = 0;
        }
    }
}

// Solves one instance modulo 998244353.
//
// a is 1-indexed (a[1..n]). Each a[i] is either a node in [1, n] or -1 for
// "unassigned". Viewing i -> a[i] as a partial functional graph with comp
// components, cy of which are acyclic, the result is
//     sum_{k>=1} dp[comp][k] * (k-1)! * n^(cy-k)  +  (comp - cy) * n^cy
// Presumably this is the total number of cycles over all n^cy ways to fill the
// -1 entries; TODO confirm against the problem statement.
//
// Overwrites the global scratch state above; requires n < 2005.
long long sol1(int n, vector<int> a)
{
    const long long MOD = 998244353;

    // Reset the global scratch state left over from any previous call.
    for (int i = 0; i < 2005; ++i)
    {
        v[i].clear();
    }
    memset(aa, 0, sizeof(aa));
    memset(vis, 0, sizeof(vis));
    memset(sz, 0, sizeof(sz));
    memset(cyc, 0, sizeof(cyc));
    memset(fact, 0, sizeof(fact));
    comp = cy = res = 0;
    t = 1;
    // Only dp[0..n][0..n] is touched below (comp <= n), so clearing that corner
    // is enough; clearing the whole 2005x2005 table cost ~32 MB of writes per call.
    for (int i = 0; i <= n; ++i)
    {
        memset(dp[i], 0, sizeof(dp[i][0]) * (n + 1));
    }

    fact[0] = 1;
    for (long long i = 1; i <= n; i++)
    {
        fact[i] = (i * fact[i - 1]) % MOD;
    }

    // Build the undirected graph i -- aa[i] on 0-based nodes. An unassigned
    // entry (-1) becomes -2 after the shift and contributes no edge.
    for (long long i = 0; i < n; i++)
    {
        aa[i] = a[i + 1] - 1;
        cyc[i] = 1;
        if (aa[i] != -2)
        {
            v[i].push_back(aa[i]);
            v[aa[i]].push_back(i);
        }
    }
    dp[0][0] = 1;

    // Label components 0..comp-1; cy counts the ones that are still acyclic.
    for (long long i = 0; i < n; i++)
    {
        if (!vis[i])
        {
            f(i, comp, -1);
            cy += cyc[comp];
            comp++;
        }
    }

    // dp[i][j] = sum over j-element subsets S of the first i components of
    // prod_{c in S} sz[c] * cyc[c]; cyclic components contribute 0.
    for (long long i = 1; i <= comp; i++)
    {
        dp[i][0] = dp[i - 1][0];
        for (long long j = 1; j <= comp; j++)
        {
            dp[i][j] = (dp[i - 1][j] + dp[i - 1][j - 1] * sz[i - 1] * cyc[i - 1]) % MOD;
        }
    }

    // Term k: dp[comp][k] * (k-1)! * n^(cy-k). The power loop does not run when
    // k > cy, but dp[comp][k] is 0 there anyway.
    for (long long k = 1; k <= comp; k++)
    {
        long long ways = 1;
        for (long long j = 0; j < cy - k; j++)
        {
            ways = (ways * n) % MOD;
        }
        res = (res + ((ways * fact[k - 1]) % MOD) * dp[comp][k]) % MOD;
    }

    // res does not yet include the cycles already present in the input (the
    // comp - cy components with cyc == 0). Each of them is counted once for every
    // one of the n^cy fillings, so the returned value differs from res by that term.
    for (long long i = 0; i < cy; i++)
    {
        t = (t * n) % MOD;
    }
    return (res + (comp - cy) * t) % MOD;
}
// Disjoint-set union (union-find) over the elements 0 .. n-1.
// fa[x] is x's parent link (a root r has fa[r] == r). sz[r] is the number of
// elements in the set whose root is r; it is only meaningful at roots.
struct DSU {
    std::vector<int> fa, sz;

    DSU() {}
    DSU(int n) { init(n); }

    // (Re)initialise to the n singleton sets {0}, {1}, ..., {n-1}.
    void init(int n) {
        fa.resize(n);
        for (int i = 0; i < n; ++i) {
            fa[i] = i;
        }
        sz.assign(n, 1);
    }

    // Root of x's set. Every node on the walked path is relinked directly to
    // that root (full path compression).
    int find(int x) {
        int root = x;
        while (fa[root] != root) {
            root = fa[root];
        }
        while (fa[x] != root) {
            const int up = fa[x];
            fa[x] = root;
            x = up;
        }
        return root;
    }

    // Joins the sets of x and y, keeping x's root as the new root.
    // Returns false, changing nothing, when they already share a set.
    bool merge(int x, int y) {
        const int rx = find(x);
        const int ry = find(y);
        if (rx == ry) {
            return false;
        }
        sz[rx] += sz[ry];
        fa[ry] = rx;
        return true;
    }

    // True iff x and y are in the same set.
    bool same(int x, int y) { return find(x) == find(y); }

    // Number of elements in x's set.
    int size(int x) { return sz[find(x)]; }
};
// Stress tester. It repeatedly fills a 1-indexed array a[1..n] where each entry
// is either -1 or a random value in [1, n], and stops at the first case on which
// output1 and output2 disagree, printing that case.
// NOTE(review): uid, output1 and output2 are not defined in this file as shown.
// Presumably uid(lo, hi) draws a uniform integer using rng, and output1/output2
// are the results of the two solutions being compared (e.g. sol1(n, a)); the
// lines computing them appear to be missing before the comparison. TODO restore.
int main() {
    const int kIterations = 1000;
    int n = 3;
    vector<int> a(n + 1, -1);
    // iter is 1-based. The old `while(t--)` loop reported 1000 - t + 1, which
    // was one too high.
    for (int iter = 1; iter <= kIterations; ++iter) {
        for (int i = 1; i <= n; ++i) {
            // Coin flip: leave the entry unassigned (-1) or point it at a random node.
            a[i] = (uid(0, 1) == 0) ? -1 : uid(1, n);
        }
        if (output1 != output2) {
            cout << "FOUND WRONG CASE ON ITERATION " << iter << ": " << endl;
            cout << n << endl;
            for (int i = 1; i <= n; ++i) cout << a[i] << " ";
            cout << endl;
            cout << "got " << output1 << " expected " << output2 << endl;
            return 0;
        }
    }
}