0% found this document useful (0 votes)
12 views

Message

balls (2)

Uploaded by

mingmeisun22
Copyright
© All Rights Reserved
Available Formats
Download as TXT or PDF, or read online on Scribd
0% found this document useful (0 votes)
12 views

Message

balls (2)

Uploaded by

mingmeisun22
Copyright
© All Rights Reserved
Available Formats
Download as TXT or PDF, or read online on Scribd
You are on page 1/ 5

#include <bits/stdc++.h>
using namespace std;

// Draws one int uniformly from the closed range [a, b] using the global `rng`.
#define uid(a, b) uniform_int_distribution<int>(a, b)(rng)

// Global Mersenne Twister generator, seeded from the system clock tick count.
mt19937 rng(std::chrono::system_clock::now().time_since_epoch().count());

// File-scope n; sol1, solve and main each use their own local `n` instead, so
// nothing visible in this file reads this variable.
long long n;
// Undirected adjacency lists: filled by sol1, traversed by f.
vector<long long>v[2005];
// Scratch state shared by sol1 and f (sol1 resets all of it on every call):
//   comp  - number of components discovered so far
//   aa    - 0-based target of each node; an input value of -1 is stored as -2
//   vis   - visited flags for the traversal in f
//   sz    - node count of each component, indexed by component number
//   cyc   - per component: set to 1 by sol1, cleared to 0 by f
//   dp    - table filled by sol1
//   fact  - factorials modulo 998244353
//   cy    - sum of cyc[] over all components
//   res,t - accumulators for sol1's return value
long long comp=0,aa[2005],vis[2005],sz[2005],cyc[2005],dp[2005]
[2005],fact[2005],cy=0,res=0,t=1;
// Depth-first traversal over the undirected adjacency lists `v`, starting at
// node `a`, that assigns every node it reaches to component `x`.
//
//   a - node being visited
//   x - index of the component being built; sz[x] is incremented once per node
//   p - node the traversal arrived from, or -1 for the starting node
//
// Clears cyc[x] to 0 when the component contains a cycle: either an already
// visited neighbour that is not the parent, or a parent link that exists in
// both directions (a -> p and p -> a).
void f(long long a,long long x,long long p)
{
    vis[a]=1,sz[x]++;
    for(long long nb : v[a])
    {
        if(!vis[nb])
        {
            f(nb,x,a);
        }
        // The two-way-edge test has to use aa[a], the target of the current
        // node. Indexing aa with the adjacency-list position compared against
        // an unrelated node and marked plain chains such as 1->2->3 as cyclic.
        else if(nb!=p||(aa[a]==nb&&aa[nb]==a))
        {
            cyc[x]=0;
        }
    }
}
long long sol1(int n, vector<int> a)
{
for(int i = 0; i < 2005; ++i) {
v[i].clear();
}
memset(aa, 0, sizeof(aa));
memset(vis, 0, sizeof(vis));
memset(sz, 0, sizeof(sz));
memset(cyc, 0, sizeof(cyc));
memset(dp, 0, sizeof(dp));
memset(fact, 0, sizeof(fact));
comp = cy = res = 0ll;
fact[0]=1;
t = 1;
for(long long i=1; i<=n; i++)
{
fact[i]=(i*fact[i-1])%998244353;
}
for(long long i=0; i<n; i++)
{
aa[i] = a[i + 1];
vis[i]=0,sz[i]=0,cyc[i]=1,aa[i]--;
if(aa[i]!=-2)
{
v[i].push_back(aa[i]);
v[aa[i]].push_back(i);
}
dp[i][0]=0,dp[0][i]=0;
}

dp[0][n]=0,dp[n][0]=0,dp[0][0]=1;
for(long long i=0; i<n; i++)
{
if(!vis[i])
{
f(i,comp++,-1);
cy+=cyc[comp-1];
}
}
for(long long i=1; i<=comp; i++)
{
dp[i][0]=dp[i-1][0];
for(long long j=1; j<=comp; j++)
{
dp[i][j]=(dp[i-1][j]+dp[i-1][j-1]*sz[i-1]*cyc[i-1])%998244353;
}
}
for(long long i=1; i<=comp; i++)
{
long long ans=1,nn=0;
for(long long j=0; j<cy-i; j++)
{
ans=(ans*n)%998244353;
}
res=(res+(((ans*fact[i-1])%998244353)*dp[comp][i]))%998244353;
}
//I think you outputted res here, but the value of res here is different than
the value of sol1 in the stresstester
for(long long i=0; i<cy; i++)
{
t=(t*n)%998244353;
}
return (res+(comp-cy)*t)%998244353;
}

// 64-bit signed integer used by the modular-arithmetic helpers below.
using i64 = long long;


// Floating-point alias; not referenced by the code visible in this file.
using db = double;
// N and LOGN are not referenced by the code visible in this file.
constexpr int N = 1e6 + 50, LOGN = 30;
// P is the modulus used by norm() and Z; inf is not referenced in the visible code.
constexpr i64 P = 998244353, inf = 1e9;
// Repeated alias declaration; it names the same type as the one above.
using i64 = long long;
// Brings x into [0, P) with at most one addition and one subtraction of P.
// Precondition: -P <= x < 2P.
i64 norm(i64 x) {
    const i64 lifted = (x < 0) ? x + P : x;
    return (lifted >= P) ? lifted - P : lifted;
}
// Raises a to the b-th power by repeated squaring, using T's operator*=.
// T must be constructible from the integer 1.
template<class T>
T power(T a, i64 b) {
    T acc = 1;
    while (b) {
        if (b % 2) {
            acc *= a;
        }
        b /= 2;
        a *= a;
    }
    return acc;
}
struct Z {
i64 x;
Z(i64 x = 0) : x(norm(x % P)) {}
i64 val() const {
return x;
}
Z operator-() const {
return Z(norm(P - x));
}
Z inv() const {
assert(x != 0);
return power(*this, P - 2);
}
Z &operator*=(const Z &rhs) {
x = i64(x) * rhs.x % P;
return *this;
}
Z &operator+=(const Z &rhs) {
x = norm(x + rhs.x);
return *this;
}
Z &operator-=(const Z &rhs) {
x = norm(x - rhs.x);
return *this;
}
Z &operator/=(const Z &rhs) {
return *this *= rhs.inv();
}
friend Z operator*(const Z &lhs, const Z &rhs) {
Z res = lhs;
res *= rhs;
return res;
}
friend Z operator+(const Z &lhs, const Z &rhs) {
Z res = lhs;
res += rhs;
return res;
}
friend Z operator-(const Z &lhs, const Z &rhs) {
Z res = lhs;
res -= rhs;
return res;
}
friend Z operator/(const Z &lhs, const Z &rhs) {
Z res = lhs;
res /= rhs;
return res;
}
friend std::istream &operator>>(std::istream &is, Z &a) {
i64 v;
is >> v;
a = Z(v);
return is;
}
friend std::ostream &operator<<(std::ostream &os, const Z &a) {
return os << a.val();
}
};

// Disjoint-set union over the integers 0..n-1 that also tracks set sizes.
struct DSU {
    std::vector<int> fa;  // parent links; a root is its own parent
    std::vector<int> sz;  // element count of the set, read only at roots

    DSU() = default;
    DSU(int n) { init(n); }

    // Resets the structure to n singleton sets.
    void init(int n) {
        fa.resize(n);
        for (int i = 0; i < n; ++i) {
            fa[i] = i;
        }
        sz.assign(n, 1);
    }

    // Returns the root of x's set and re-points every node on the walked path
    // directly at that root.
    int find(int x) {
        int root = x;
        while (fa[root] != root) {
            root = fa[root];
        }
        while (fa[x] != root) {
            const int up = fa[x];
            fa[x] = root;
            x = up;
        }
        return root;
    }

    // Joins the sets of x and y, keeping x's root as the root of the union.
    // Returns false when they were already in the same set.
    bool merge(int x, int y) {
        const int rx = find(x);
        const int ry = find(y);
        if (rx == ry) {
            return false;
        }
        sz[rx] += sz[ry];
        fa[ry] = rx;
        return true;
    }

    // True when x and y share a root.
    bool same(int x, int y) {
        return find(x) == find(y);
    }

    // Number of elements in x's set.
    int size(int x) {
        return sz[find(x)];
    }
};

long long solve(int n, vector<int> a){


DSU dsu(n + 1);
for (int i = 1; i <= n; i++) {
if (a[i] != -1)
dsu.merge(i, a[i]);
}
vector<int> cnt(n + 1);
for (int i = 1; i <= n; i++) {
if (a[i] == -1) cnt[dsu.find(i)] = 1;
}
vector<int> sz{0};
Z ans = 0;
vector<Z> pow(n + 1);
pow[0] = 1;
for (int i = 1; i <= n; i++) pow[i] = pow[i - 1] * n;
for (int i = 1; i <= n; i++) {
if (i == dsu.find(i) && cnt[i]) {
sz.emplace_back(dsu.size(i));
}
}
int m = sz.size() - 1;
for (int i = 1; i <= n; i++) {
if (i == dsu.find(i) && !cnt[i]) ans += pow[m];
}
vector<vector<Z>> dp(m + 1, vector<Z> (m + 1));
for (int i = 0; i <= m; i++) dp[i][0] = 1;
for (int i = 1; i <= m; i++) {
for (int j = 1; j <= m; j++) {
dp[i][j] = dp[i - 1][j] + dp[i - 1][j - 1] * max(1, j - 1) *
sz[i];
}
}
for (int i = 1; i <= m; i++) {
ans += dp[m][i] * pow[m - i];
}
return ans.val();

// Stress test: compares sol1 with solve on the all -1 input of length 3, then
// on 1000 random inputs of the same length. On the first mismatch it prints the
// failing input and both results, then stops. Prints nothing when all agree.
int main() {
    const int n = 3;
    vector<int> a(n + 1, -1);

    if(sol1(n, a) != solve(n, a)) {
        cout << "all -1" << endl;
        return 0;
    }

    const int iterations = 1000;
    // iter is the 1-based number of the current random case. The previous
    // expression (1000 - t + 1, evaluated after t--) reported one too many.
    for(int iter = 1; iter <= iterations; ++iter) {
        // Each entry is unknown (-1) with probability 1/2, otherwise uniform
        // in 1..n.
        for(int i = 1; i <= n; ++i) {
            int val = uid(0, 1);
            if(val == 0) {
                a[i] = -1;
            }
            else {
                a[i] = uid(1, n);
            }
        }

        long long output1 = sol1(n, a);
        long long output2 = solve(n, a);

        if(output1 != output2) {
            cout << "FOUND WRONG CASE ON ITERATION " << iter << ": " <<
            endl;
            cout << n << endl;
            for(int i = 1; i <= n; ++i) cout << a[i] << " ";
            cout << endl;

            cout << "got " << output1 << " expected " << output2 << endl;

            return 0;
        }
    }
}

You might also like