0% found this document useful (0 votes)
14 views

Message

balls (2)

Uploaded by

mingmeisun22
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as TXT or PDF, or read online on Scribd
0% found this document useful (0 votes)
14 views

Message

balls (2)

Uploaded by

mingmeisun22
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as TXT or PDF, or read online on Scribd
You are on page 1/5

#include <bits/stdc++.h>
using namespace std;

// Uniformly random int in the closed range [a, b], drawn from the global engine `rng`.
#define uid(a, b) uniform_int_distribution<int>(a, b)(rng)

// Mersenne Twister seeded from the wall clock, so every run of the stress test sees new cases.
mt19937 rng(std::chrono::system_clock::now().time_since_epoch().count());

// ---- Global state shared by f() and sol1(); sol1() resets it on every call. ----

// Not read by sol1 or main: sol1's parameter `n` and main's local `n` both shadow it.
long long n;
// Undirected adjacency lists: each edge i -> aa[i] is stored in both directions.
vector<long long> v[2005];
// Number of connected components labelled so far.
long long comp = 0;
// 0-based successor of each node; -2 when the 1-based input value was -1.
long long aa[2005];
// DFS visited flags.
long long vis[2005];
// Node count of each component (indexed by component id).
long long sz[2005];
// 1 while a component is believed acyclic, cleared to 0 by f() once a cycle is found.
long long cyc[2005];
// DP table filled by sol1 over (components processed, components chosen).
long long dp[2005][2005];
// Factorials modulo 998244353.
long long fact[2005];
// Number of components with cyc == 1.
long long cy = 0;
// Partial answer accumulated by sol1.
long long res = 0;
// Scratch power n^cy computed by sol1.
long long t = 1;
/**
 * Depth-first search over the undirected view of the functional graph built by sol1().
 *
 * Marks every node reachable from `a` as visited, adds each one to sz[x], and
 * clears cyc[x] as soon as a cycle is detected in the component.
 *
 * @param a  node being visited (0-based)
 * @param x  id of the component being explored
 * @param p  DFS parent of `a`, or -1 for the component's start node
 */
void f(long long a,long long x,long long p)
{
    vis[a] = 1;
    sz[x]++;
    for (long long nb : v[a])
    {
        if (!vis[nb])
        {
            f(nb, x, a);
        }
        // A visited neighbour closes a cycle unless it is merely the edge back to the
        // DFS parent. That parent edge is itself a 2-cycle when a and p point at each
        // other (aa[a] == p and aa[p] == a), which appears as a doubled edge.
        // (Previously this read aa[i] with i being the adjacency index, not a node,
        // which flagged acyclic trees such as a = [_, 2, 3, -1] as cyclic.)
        else if (nb != p || (aa[a] == nb && aa[nb] == a))
        {
            cyc[x] = 0;
        }
    }
}
/**
 * Candidate solution checked by the stress tester in main().
 *
 * @param n  number of nodes; must be <= 2004 because of the fixed global array sizes
 * @param a  1-based successors a[1..n]; -1 marks a node with no successor yet
 *           (a[0] is ignored)
 * @return   the value compared against solve(n, a), reduced modulo 998244353
 *
 * Resets and then overwrites the globals v, aa, vis, sz, cyc, dp, fact, comp, cy,
 * res and t, so calls must not overlap.
 */
long long sol1(int n, vector<int> a)
{
    const long long MOD = 998244353;

    // Reset all global state left over from the previous call.
    for (int i = 0; i < 2005; ++i) {
        v[i].clear();
    }
    memset(aa, 0, sizeof(aa));
    memset(vis, 0, sizeof(vis));
    memset(sz, 0, sizeof(sz));
    memset(cyc, 0, sizeof(cyc));
    memset(dp, 0, sizeof(dp));
    memset(fact, 0, sizeof(fact));
    comp = cy = res = 0ll;
    t = 1;

    // fact[i] = i! mod MOD
    fact[0] = 1;
    for (long long i = 1; i <= n; i++)
    {
        fact[i] = (i * fact[i - 1]) % MOD;
    }

    // Convert to 0-based successors (-1 in the input becomes -2) and store every
    // existing edge in both directions.
    for (long long i = 0; i < n; i++)
    {
        aa[i] = a[i + 1];
        vis[i] = 0, sz[i] = 0, cyc[i] = 1, aa[i]--;
        if (aa[i] != -2)
        {
            v[i].push_back(aa[i]);
            v[aa[i]].push_back(i);
        }
        dp[i][0] = 0, dp[0][i] = 0;
    }
    dp[0][n] = 0, dp[n][0] = 0, dp[0][0] = 1;

    // Label components; cy counts the ones in which f() found no cycle.
    for (long long i = 0; i < n; i++)
    {
        if (!vis[i])
        {
            f(i, comp++, -1);
            cy += cyc[comp - 1];
        }
    }

    // dp[i][j]: over the first i components, the sum over every choice of j acyclic
    // components of the product of their sizes (components with a cycle weigh 0).
    for (long long i = 1; i <= comp; i++)
    {
        dp[i][0] = dp[i - 1][0];
        for (long long j = 1; j <= comp; j++)
        {
            dp[i][j] = (dp[i - 1][j] + dp[i - 1][j - 1] * sz[i - 1] * cyc[i - 1]) % MOD;
        }
    }

    // For each i: dp[comp][i] * (i-1)! * n^(cy-i). The (i-1)! factor presumably counts
    // the ways to link the i chosen components into one new cycle.
    for (long long i = 1; i <= comp; i++)
    {
        long long ans = 1;
        for (long long j = 0; j < cy - i; j++)
        {
            ans = (ans * n) % MOD;
        }
        res = (res + (((ans * fact[i - 1]) % MOD) * dp[comp][i])) % MOD;
    }

    // Note from the original author: printing `res` at this point did not match the
    // value sol1 returns to the stress tester. That is expected: the contribution of
    // components that already contain a cycle, (comp - cy) * n^cy, is only added in
    // the return statement below.
    for (long long i = 0; i < cy; i++)
    {
        t = (t * n) % MOD;
    }
    return (res + (comp - cy) * t) % MOD;
}

using i64 = long long;
using db = double;

constexpr int N = 1e6 + 50, LOGN = 30;
// Prime modulus used by norm() and the modular-integer type Z.
constexpr i64 P = 998244353, inf = 1e9;

// Reduces x into the canonical range [0, P).
// Precondition: -P <= x < 2P (a single correction step is applied in each direction).
i64 norm(i64 x) {
    if (x < 0) {
        x += P;
    }
    if (x >= P) {
        x -= P;
    }
    return x;
}
// Square-and-multiply exponentiation: returns a raised to the power b using
// O(log |b|) multiplications of T. Callers in this file pass b >= 0.
template<class T>
T power(T a, i64 b) {
    T acc = 1;
    while (b != 0) {
        // The lowest remaining bit of b decides whether this square joins the product.
        if (b % 2 != 0) {
            acc *= a;
        }
        b /= 2;
        a *= a;
    }
    return acc;
}
struct Z {
i64 x;
Z(i64 x = 0) : x(norm(x % P)) {}
i64 val() const {
return x;
}
Z operator-() const {
return Z(norm(P - x));
}
Z inv() const {
assert(x != 0);
return power(*this, P - 2);
}
Z &operator*=(const Z &rhs) {
x = i64(x) * rhs.x % P;
return *this;
}
Z &operator+=(const Z &rhs) {
x = norm(x + rhs.x);
return *this;
}
Z &operator-=(const Z &rhs) {
x = norm(x - rhs.x);
return *this;
}
Z &operator/=(const Z &rhs) {
return *this *= rhs.inv();
}
friend Z operator*(const Z &lhs, const Z &rhs) {
Z res = lhs;
res *= rhs;
return res;
}
friend Z operator+(const Z &lhs, const Z &rhs) {
Z res = lhs;
res += rhs;
return res;
}
friend Z operator-(const Z &lhs, const Z &rhs) {
Z res = lhs;
res -= rhs;
return res;
}
friend Z operator/(const Z &lhs, const Z &rhs) {
Z res = lhs;
res /= rhs;
return res;
}
friend std::istream &operator>>(std::istream &is, Z &a) {
i64 v;
is >> v;
a = Z(v);
return is;
}
friend std::ostream &operator<<(std::ostream &os, const Z &a) {
return os << a.val();
}
};

// Disjoint-set union with full path compression. On merge, the root of the first
// argument's set absorbs the second set and its size.
struct DSU {
    vector<int> fa;  // parent links; a root is its own parent
    vector<int> sz;  // set sizes, meaningful only at roots

    DSU() {}
    DSU(int n) { init(n); }

    // Resets to n singleton sets {0}, {1}, ..., {n-1}.
    void init(int n){
        fa.resize(n);
        for (int node = 0; node < n; ++node) {
            fa[node] = node;
        }
        sz.assign(n, 1);
    }

    // Returns the root of x's set and re-points every node on the path directly at it.
    int find(int x){
        int root = x;
        while (fa[root] != root) {
            root = fa[root];
        }
        while (fa[x] != root) {
            int next = fa[x];
            fa[x] = root;
            x = next;
        }
        return root;
    }

    // Unites the sets of x and y; returns false if they were already the same set.
    bool merge(int x, int y) {
        int rootX = find(x);
        int rootY = find(y);
        if (rootX == rootY) {
            return false;
        }
        sz[rootX] += sz[rootY];
        fa[rootY] = rootX;
        return true;
    }

    bool same(int x, int y) { return find(x) == find(y); }

    int size(int x) { return sz[find(x)]; }
};

long long solve(int n, vector<int> a){


DSU dsu(n + 1);
for (int i = 1; i <= n; i++) {
if (a[i] != -1)
dsu.merge(i, a[i]);
}
vector<int> cnt(n + 1);
for (int i = 1; i <= n; i++) {
if (a[i] == -1) cnt[dsu.find(i)] = 1;
}
vector<int> sz{0};
Z ans = 0;
vector<Z> pow(n + 1);
pow[0] = 1;
for (int i = 1; i <= n; i++) pow[i] = pow[i - 1] * n;
for (int i = 1; i <= n; i++) {
if (i == dsu.find(i) && cnt[i]) {
sz.emplace_back(dsu.size(i));
}
}
int m = sz.size() - 1;
for (int i = 1; i <= n; i++) {
if (i == dsu.find(i) && !cnt[i]) ans += pow[m];
}
vector<vector<Z>> dp(m + 1, vector<Z> (m + 1));
for (int i = 0; i <= m; i++) dp[i][0] = 1;
for (int i = 1; i <= m; i++) {
for (int j = 1; j <= m; j++) {
dp[i][j] = dp[i - 1][j] + dp[i - 1][j - 1] * max(1, j - 1) *
sz[i];
}
}
for (int i = 1; i <= m; i++) {
ans += dp[m][i] * pow[m - i];
}
return ans.val();

/**
 * Stress test: compares sol1 (candidate) against solve (reference) on a fixed
 * all-unassigned input and then on random inputs. Prints the first input on which
 * they disagree; prints nothing when every case matches.
 */
int main() {
    int n = 3;
    vector<int> a(n + 1, -1);

    // Deterministic sanity check first: every successor unassigned.
    if (sol1(n, a) != solve(n, a)) {
        cout << "all -1" << endl;
        return 0;
    }

    const int iterations = 1000;
    for (int iter = 1; iter <= iterations; ++iter) {
        // Each node is left unassigned (-1) with probability 1/2; otherwise it points
        // to a uniformly random node in [1, n].
        for (int i = 1; i <= n; ++i) {
            int val = uid(0, 1);
            if (val == 0) {
                a[i] = -1;
            }
            else {
                a[i] = uid(1, n);
            }
        }

        long long output1 = sol1(n, a);
        long long output2 = solve(n, a);

        if (output1 != output2) {
            // iter is 1-based. The old `1000 - t + 1` after `while (t--)` reported 2..1001.
            cout << "FOUND WRONG CASE ON ITERATION " << iter << ": " << endl;
            cout << n << endl;
            for (int i = 1; i <= n; ++i) cout << a[i] << " ";
            cout << endl;

            cout << "got " << output1 << " expected " << output2 << endl;

            return 0;
        }
    }
}

You might also like