From d4490664ec80f52d23c4345eec5771764bcdbb17 Mon Sep 17 00:00:00 2001
From: Antonin Houska <ah@cybertec.at>
Date: Wed, 15 Mar 2023 04:21:01 +0100
Subject: [PATCH 1/2] Move some code into functions.

This is only a preparation for a patch that introduces USAGE privilege on
publications. It should make the following diff a little bit easier to read.
---
 src/backend/catalog/pg_publication.c        | 236 +++++++++++++++++---
 src/backend/commands/copy.c                 |  81 +------
 src/backend/commands/copyto.c               |  89 ++++++++
 src/backend/replication/pgoutput/pgoutput.c | 139 +-----------
 src/include/catalog/pg_publication.h        |   6 +
 src/include/commands/copy.h                 |   2 +
 6 files changed, 308 insertions(+), 245 deletions(-)

diff --git a/src/backend/catalog/pg_publication.c b/src/backend/catalog/pg_publication.c
index a98fcad421..7f6024b7a5 100644
--- a/src/backend/catalog/pg_publication.c
+++ b/src/backend/catalog/pg_publication.c
@@ -1025,6 +1025,208 @@ GetPublicationByName(const char *pubname, bool missing_ok)
 	return OidIsValid(oid) ? GetPublication(oid) : NULL;
 }
 
+/*
+ * Get the mapping for given publication and relation.
+ */
+void
+GetPublicationRelationMapping(Oid pubid, Oid relid,
+							  Datum *attrs, bool *attrs_isnull,
+							  Datum *qual, bool *qual_isnull)
+{
+	Publication *publication;
+	HeapTuple	pubtuple = NULL;
+	Oid			schemaid = get_rel_namespace(relid);
+
+	publication = GetPublication(pubid);
+
+	/*
+	 * We don't consider row filters or column lists for FOR ALL TABLES or
+	 * FOR TABLES IN SCHEMA publications.
+	 */
+	if (!publication->alltables &&
+		!SearchSysCacheExists2(PUBLICATIONNAMESPACEMAP,
+							   ObjectIdGetDatum(schemaid),
+							   ObjectIdGetDatum(publication->oid)))
+		pubtuple = SearchSysCacheCopy2(PUBLICATIONRELMAP,
+									   ObjectIdGetDatum(relid),
+									   ObjectIdGetDatum(publication->oid));
+
+	if (HeapTupleIsValid(pubtuple))
+	{
+		/* Lookup the column list attribute. */
+		*attrs = SysCacheGetAttr(PUBLICATIONRELMAP, pubtuple,
+								 Anum_pg_publication_rel_prattrs,
+								 attrs_isnull);
+
+		/* Null indicates no filter. */
+		*qual = SysCacheGetAttr(PUBLICATIONRELMAP, pubtuple,
+								Anum_pg_publication_rel_prqual,
+								qual_isnull);
+	}
+	else
+	{
+		*attrs_isnull = true;
+		*qual_isnull = true;
+	}
+}
+/*
+ * Pick those publications from a list which should actually be used to
+ * publish given relation and return them.
+ *
+ * If publish_as_relid_p is passed, the relation whose tuple descriptor should
+ * be used to publish the data is stored in *publish_as_relid_p.
+ *
+ * If pubactions is passed, update the structure according to the matching
+ * publications.
+ */
+List *
+GetEffectiveRelationPublications(Oid relid, List *publications,
+								 Oid *publish_as_relid_p,
+								 PublicationActions *pubactions)
+{
+	Oid			schemaId = get_rel_namespace(relid);
+	List	   *pubids = GetRelationPublications(relid);
+	/*
+	 * We don't acquire a lock on the namespace system table as we build the
+	 * cache entry using a historic snapshot and all the later changes are
+	 * absorbed while decoding WAL.
+	 */
+	List	   *schemaPubids = GetSchemaPublications(schemaId);
+	ListCell   *lc;
+	Oid			publish_as_relid = relid;
+	int			publish_ancestor_level = 0;
+	bool		am_partition = get_rel_relispartition(relid);
+	char		relkind = get_rel_relkind(relid);
+	List	   *rel_publications = NIL;
+
+	foreach(lc, publications)
+	{
+		Publication *pub = lfirst(lc);
+		bool		publish = false;
+
+		/*
+		 * Under what relid should we publish changes in this publication?
+		 * We'll use the top-most relid across all publications. Also track
+		 * the ancestor level for this publication.
+		 */
+		Oid	pub_relid = relid;
+		int	ancestor_level = 0;
+
+		/*
+		 * If this is a FOR ALL TABLES publication, pick the partition root
+		 * and set the ancestor level accordingly.
+		 */
+		if (pub->alltables)
+		{
+			publish = true;
+			if (pub->pubviaroot && am_partition)
+			{
+				List	   *ancestors = get_partition_ancestors(relid);
+
+				pub_relid = llast_oid(ancestors);
+				ancestor_level = list_length(ancestors);
+			}
+		}
+
+		if (!publish)
+		{
+			bool		ancestor_published = false;
+
+			/*
+			 * For a partition, check if any of the ancestors are published.
+			 * If so, note down the topmost ancestor that is published via
+			 * this publication, which will be used as the relation via which
+			 * to publish the partition's changes.
+			 */
+			if (am_partition)
+			{
+				Oid			ancestor;
+				int			level;
+				List	   *ancestors = get_partition_ancestors(relid);
+
+				ancestor = GetTopMostAncestorInPublication(pub->oid,
+														   ancestors,
+														   &level);
+
+				if (ancestor != InvalidOid)
+				{
+					ancestor_published = true;
+					if (pub->pubviaroot)
+					{
+						pub_relid = ancestor;
+						ancestor_level = level;
+					}
+				}
+			}
+
+			if (list_member_oid(pubids, pub->oid) ||
+				list_member_oid(schemaPubids, pub->oid) ||
+				ancestor_published)
+				publish = true;
+		}
+
+		/*
+		 * If the relation is to be published, determine actions to publish,
+		 * and list of columns, if appropriate.
+		 *
+		 * Don't publish changes for partitioned tables, because publishing
+		 * those of its partitions suffices, unless partition changes won't be
+		 * published due to pubviaroot being set.
+		 */
+		if (publish &&
+			(relkind != RELKIND_PARTITIONED_TABLE || pub->pubviaroot))
+		{
+			if (pubactions)
+			{
+				pubactions->pubinsert |= pub->pubactions.pubinsert;
+				pubactions->pubupdate |= pub->pubactions.pubupdate;
+				pubactions->pubdelete |= pub->pubactions.pubdelete;
+				pubactions->pubtruncate |= pub->pubactions.pubtruncate;
+			}
+
+			/*
+			 * We want to publish the changes as the top-most ancestor across
+			 * all publications. So we need to check if the already calculated
+			 * level is higher than the new one. If yes, we can ignore the new
+			 * value (as it's a child). Otherwise the new value is an
+			 * ancestor, so we keep it.
+			 */
+			if (publish_ancestor_level > ancestor_level)
+				continue;
+
+			/*
+			 * If we found an ancestor higher up in the tree, discard the list
+			 * of publications through which we replicate it, and use the new
+			 * ancestor.
+			 */
+			if (publish_ancestor_level < ancestor_level)
+			{
+				publish_as_relid = pub_relid;
+				publish_ancestor_level = ancestor_level;
+
+				/* reset the publication list for this relation */
+				rel_publications = NIL;
+			}
+			else
+			{
+				/* Same ancestor level, has to be the same OID. */
+				Assert(publish_as_relid == pub_relid);
+			}
+
+			/* Track publications for this ancestor. */
+			rel_publications = lappend(rel_publications, pub);
+		}
+	}
+
+	list_free(pubids);
+	list_free(schemaPubids);
+
+	if (publish_as_relid_p)
+		*publish_as_relid_p = publish_as_relid;
+
+	return rel_publications;
+}
+
 /*
  * Returns information of tables in a publication.
  */
@@ -1108,10 +1310,8 @@ pg_get_publication_tables(PG_FUNCTION_ARGS)
 
 	if (funcctx->call_cntr < list_length(tables))
 	{
-		HeapTuple	pubtuple = NULL;
 		HeapTuple	rettuple;
 		Oid			relid = list_nth_oid(tables, funcctx->call_cntr);
-		Oid			schemaid = get_rel_namespace(relid);
 		Datum		values[NUM_PUBLICATION_TABLES_ELEM] = {0};
 		bool		nulls[NUM_PUBLICATION_TABLES_ELEM] = {0};
 
@@ -1123,35 +1323,9 @@ pg_get_publication_tables(PG_FUNCTION_ARGS)
 
 		values[0] = ObjectIdGetDatum(relid);
 
-		/*
-		 * We don't consider row filters or column lists for FOR ALL TABLES or
-		 * FOR TABLES IN SCHEMA publications.
-		 */
-		if (!publication->alltables &&
-			!SearchSysCacheExists2(PUBLICATIONNAMESPACEMAP,
-								   ObjectIdGetDatum(schemaid),
-								   ObjectIdGetDatum(publication->oid)))
-			pubtuple = SearchSysCacheCopy2(PUBLICATIONRELMAP,
-										   ObjectIdGetDatum(relid),
-										   ObjectIdGetDatum(publication->oid));
-
-		if (HeapTupleIsValid(pubtuple))
-		{
-			/* Lookup the column list attribute. */
-			values[1] = SysCacheGetAttr(PUBLICATIONRELMAP, pubtuple,
-										Anum_pg_publication_rel_prattrs,
-										&(nulls[1]));
-
-			/* Null indicates no filter. */
-			values[2] = SysCacheGetAttr(PUBLICATIONRELMAP, pubtuple,
-										Anum_pg_publication_rel_prqual,
-										&(nulls[2]));
-		}
-		else
-		{
-			nulls[1] = true;
-			nulls[2] = true;
-		}
+		GetPublicationRelationMapping(publication->oid, relid,
+									  &values[1], &nulls[1],
+									  &values[2], &nulls[2]);
 
 		/* Show all columns when the column list is not specified. */
 		if (nulls[1] == true)
diff --git a/src/backend/commands/copy.c b/src/backend/commands/copy.c
index 167d31a2d9..8edc2c19f6 100644
--- a/src/backend/commands/copy.c
+++ b/src/backend/commands/copy.c
@@ -177,92 +177,13 @@ DoCopy(ParseState *pstate, const CopyStmt *stmt,
 		 */
 		if (check_enable_rls(relid, InvalidOid, false) == RLS_ENABLED)
 		{
-			SelectStmt *select;
-			ColumnRef  *cr;
-			ResTarget  *target;
-			RangeVar   *from;
-			List	   *targetList = NIL;
-
 			if (is_from)
 				ereport(ERROR,
 						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
 						 errmsg("COPY FROM not supported with row-level security"),
 						 errhint("Use INSERT statements instead.")));
 
-			/*
-			 * Build target list
-			 *
-			 * If no columns are specified in the attribute list of the COPY
-			 * command, then the target list is 'all' columns. Therefore, '*'
-			 * should be used as the target list for the resulting SELECT
-			 * statement.
-			 *
-			 * In the case that columns are specified in the attribute list,
-			 * create a ColumnRef and ResTarget for each column and add them
-			 * to the target list for the resulting SELECT statement.
-			 */
-			if (!stmt->attlist)
-			{
-				cr = makeNode(ColumnRef);
-				cr->fields = list_make1(makeNode(A_Star));
-				cr->location = -1;
-
-				target = makeNode(ResTarget);
-				target->name = NULL;
-				target->indirection = NIL;
-				target->val = (Node *) cr;
-				target->location = -1;
-
-				targetList = list_make1(target);
-			}
-			else
-			{
-				ListCell   *lc;
-
-				foreach(lc, stmt->attlist)
-				{
-					/*
-					 * Build the ColumnRef for each column.  The ColumnRef
-					 * 'fields' property is a String node that corresponds to
-					 * the column name respectively.
-					 */
-					cr = makeNode(ColumnRef);
-					cr->fields = list_make1(lfirst(lc));
-					cr->location = -1;
-
-					/* Build the ResTarget and add the ColumnRef to it. */
-					target = makeNode(ResTarget);
-					target->name = NULL;
-					target->indirection = NIL;
-					target->val = (Node *) cr;
-					target->location = -1;
-
-					/* Add each column to the SELECT statement's target list */
-					targetList = lappend(targetList, target);
-				}
-			}
-
-			/*
-			 * Build RangeVar for from clause, fully qualified based on the
-			 * relation which we have opened and locked.  Use "ONLY" so that
-			 * COPY retrieves rows from only the target table not any
-			 * inheritance children, the same as when RLS doesn't apply.
-			 */
-			from = makeRangeVar(get_namespace_name(RelationGetNamespace(rel)),
-								pstrdup(RelationGetRelationName(rel)),
-								-1);
-			from->inh = false;	/* apply ONLY */
-
-			/* Build query */
-			select = makeNode(SelectStmt);
-			select->targetList = targetList;
-			select->fromClause = list_make1(from);
-
-			query = makeNode(RawStmt);
-			query->stmt = (Node *) select;
-			query->stmt_location = stmt_location;
-			query->stmt_len = stmt_len;
-
+			query = CreateCopyToQuery(stmt, rel, stmt_location, stmt_len);
 			/*
 			 * Close the relation for now, but keep the lock on it to prevent
 			 * changes between now and when we start the query-based COPY.
diff --git a/src/backend/commands/copyto.c b/src/backend/commands/copyto.c
index beea1ac687..af0cdef158 100644
--- a/src/backend/commands/copyto.c
+++ b/src/backend/commands/copyto.c
@@ -32,6 +32,7 @@
 #include "libpq/pqformat.h"
 #include "mb/pg_wchar.h"
 #include "miscadmin.h"
+#include "nodes/makefuncs.h"
 #include "optimizer/optimizer.h"
 #include "pgstat.h"
 #include "rewrite/rewriteHandler.h"
@@ -339,6 +340,94 @@ EndCopy(CopyToState cstate)
 	pfree(cstate);
 }
 
+/*
+ * Turn "COPY table_name TO" form into "COPY (query) TO".
+ */
+RawStmt *
+CreateCopyToQuery(const CopyStmt *stmt, Relation rel, int stmt_location,
+				  int stmt_len)
+{
+	SelectStmt *select;
+	ColumnRef  *cr;
+	ResTarget  *target;
+	RangeVar   *from;
+	List	   *targetList = NIL;
+	RawStmt    *query = NULL;
+
+	/*
+	 * Build target list
+	 *
+	 * If no columns are specified in the attribute list of the COPY command,
+	 * then the target list is 'all' columns. Therefore, '*' should be used as
+	 * the target list for the resulting SELECT statement.
+	 *
+	 * In the case that columns are specified in the attribute list, create a
+	 * ColumnRef and ResTarget for each column and add them to the target list
+	 * for the resulting SELECT statement.
+	 */
+	if (!stmt->attlist)
+	{
+		cr = makeNode(ColumnRef);
+		cr->fields = list_make1(makeNode(A_Star));
+		cr->location = -1;
+
+		target = makeNode(ResTarget);
+		target->name = NULL;
+		target->indirection = NIL;
+		target->val = (Node *) cr;
+		target->location = -1;
+
+		targetList = list_make1(target);
+	}
+	else
+	{
+		ListCell   *lc;
+
+		foreach(lc, stmt->attlist)
+		{
+			/*
+			 * Build the ColumnRef for each column.  The ColumnRef 'fields'
+			 * property is a String node that corresponds to the column name
+			 * respectively.
+			 */
+			cr = makeNode(ColumnRef);
+			cr->fields = list_make1(lfirst(lc));
+			cr->location = -1;
+
+			/* Build the ResTarget and add the ColumnRef to it. */
+			target = makeNode(ResTarget);
+			target->name = NULL;
+			target->indirection = NIL;
+			target->val = (Node *) cr;
+			target->location = -1;
+
+			/* Add each column to the SELECT statement's target list */
+			targetList = lappend(targetList, target);
+		}
+	}
+
+	/*
+	 * Build RangeVar for from clause, fully qualified based on the relation
+	 * which we have opened and locked.
+	 */
+	from = makeRangeVar(get_namespace_name(RelationGetNamespace(rel)),
+						pstrdup(RelationGetRelationName(rel)),
+						-1);
+	from->inh = false;	/* apply ONLY */
+
+	/* Build query */
+	select = makeNode(SelectStmt);
+	select->targetList = targetList;
+	select->fromClause = list_make1(from);
+
+	query = makeNode(RawStmt);
+	query->stmt = (Node *) select;
+	query->stmt_location = stmt_location;
+	query->stmt_len = stmt_len;
+
+	return query;
+}
+
 /*
  * Setup CopyToState to read tuples from a table or a query for COPY TO.
  *
diff --git a/src/backend/replication/pgoutput/pgoutput.c b/src/backend/replication/pgoutput/pgoutput.c
index 00a2d73dab..21b8b2944e 100644
--- a/src/backend/replication/pgoutput/pgoutput.c
+++ b/src/backend/replication/pgoutput/pgoutput.c
@@ -2063,21 +2063,7 @@ get_rel_sync_entry(PGOutputData *data, Relation relation)
 	/* Validate the entry */
 	if (!entry->replicate_valid)
 	{
-		Oid			schemaId = get_rel_namespace(relid);
-		List	   *pubids = GetRelationPublications(relid);
-
-		/*
-		 * We don't acquire a lock on the namespace system table as we build
-		 * the cache entry using a historic snapshot and all the later changes
-		 * are absorbed while decoding WAL.
-		 */
-		List	   *schemaPubids = GetSchemaPublications(schemaId);
-		ListCell   *lc;
-		Oid			publish_as_relid = relid;
-		int			publish_ancestor_level = 0;
-		bool		am_partition = get_rel_relispartition(relid);
-		char		relkind = get_rel_relkind(relid);
-		List	   *rel_publications = NIL;
+		List	*rel_publications;
 
 		/* Reload publications if needed before use. */
 		if (!publications_valid)
@@ -2140,123 +2126,10 @@ get_rel_sync_entry(PGOutputData *data, Relation relation)
 		 * but here we only need to consider ones that the subscriber
 		 * requested.
 		 */
-		foreach(lc, data->publications)
-		{
-			Publication *pub = lfirst(lc);
-			bool		publish = false;
-
-			/*
-			 * Under what relid should we publish changes in this publication?
-			 * We'll use the top-most relid across all publications. Also
-			 * track the ancestor level for this publication.
-			 */
-			Oid			pub_relid = relid;
-			int			ancestor_level = 0;
-
-			/*
-			 * If this is a FOR ALL TABLES publication, pick the partition
-			 * root and set the ancestor level accordingly.
-			 */
-			if (pub->alltables)
-			{
-				publish = true;
-				if (pub->pubviaroot && am_partition)
-				{
-					List	   *ancestors = get_partition_ancestors(relid);
-
-					pub_relid = llast_oid(ancestors);
-					ancestor_level = list_length(ancestors);
-				}
-			}
-
-			if (!publish)
-			{
-				bool		ancestor_published = false;
-
-				/*
-				 * For a partition, check if any of the ancestors are
-				 * published.  If so, note down the topmost ancestor that is
-				 * published via this publication, which will be used as the
-				 * relation via which to publish the partition's changes.
-				 */
-				if (am_partition)
-				{
-					Oid			ancestor;
-					int			level;
-					List	   *ancestors = get_partition_ancestors(relid);
-
-					ancestor = GetTopMostAncestorInPublication(pub->oid,
-															   ancestors,
-															   &level);
-
-					if (ancestor != InvalidOid)
-					{
-						ancestor_published = true;
-						if (pub->pubviaroot)
-						{
-							pub_relid = ancestor;
-							ancestor_level = level;
-						}
-					}
-				}
-
-				if (list_member_oid(pubids, pub->oid) ||
-					list_member_oid(schemaPubids, pub->oid) ||
-					ancestor_published)
-					publish = true;
-			}
-
-			/*
-			 * If the relation is to be published, determine actions to
-			 * publish, and list of columns, if appropriate.
-			 *
-			 * Don't publish changes for partitioned tables, because
-			 * publishing those of its partitions suffices, unless partition
-			 * changes won't be published due to pubviaroot being set.
-			 */
-			if (publish &&
-				(relkind != RELKIND_PARTITIONED_TABLE || pub->pubviaroot))
-			{
-				entry->pubactions.pubinsert |= pub->pubactions.pubinsert;
-				entry->pubactions.pubupdate |= pub->pubactions.pubupdate;
-				entry->pubactions.pubdelete |= pub->pubactions.pubdelete;
-				entry->pubactions.pubtruncate |= pub->pubactions.pubtruncate;
-
-				/*
-				 * We want to publish the changes as the top-most ancestor
-				 * across all publications. So we need to check if the already
-				 * calculated level is higher than the new one. If yes, we can
-				 * ignore the new value (as it's a child). Otherwise the new
-				 * value is an ancestor, so we keep it.
-				 */
-				if (publish_ancestor_level > ancestor_level)
-					continue;
-
-				/*
-				 * If we found an ancestor higher up in the tree, discard the
-				 * list of publications through which we replicate it, and use
-				 * the new ancestor.
-				 */
-				if (publish_ancestor_level < ancestor_level)
-				{
-					publish_as_relid = pub_relid;
-					publish_ancestor_level = ancestor_level;
-
-					/* reset the publication list for this relation */
-					rel_publications = NIL;
-				}
-				else
-				{
-					/* Same ancestor level, has to be the same OID. */
-					Assert(publish_as_relid == pub_relid);
-				}
-
-				/* Track publications for this ancestor. */
-				rel_publications = lappend(rel_publications, pub);
-			}
-		}
-
-		entry->publish_as_relid = publish_as_relid;
+		rel_publications = GetEffectiveRelationPublications(relid,
+															data->publications,
+															&entry->publish_as_relid,
+															&entry->pubactions);
 
 		/*
 		 * Initialize the tuple slot, map, and row filter. These are only used
@@ -2275,8 +2148,6 @@ get_rel_sync_entry(PGOutputData *data, Relation relation)
 			pgoutput_column_list_init(data, rel_publications, entry);
 		}
 
-		list_free(pubids);
-		list_free(schemaPubids);
 		list_free(rel_publications);
 
 		entry->replicate_valid = true;
diff --git a/src/include/catalog/pg_publication.h b/src/include/catalog/pg_publication.h
index 6ecaa2a01e..dab5bc8444 100644
--- a/src/include/catalog/pg_publication.h
+++ b/src/include/catalog/pg_publication.h
@@ -113,6 +113,12 @@ typedef struct PublicationRelInfo
 extern Publication *GetPublication(Oid pubid);
 extern Publication *GetPublicationByName(const char *pubname, bool missing_ok);
 extern List *GetRelationPublications(Oid relid);
+extern void GetPublicationRelationMapping(Oid pubid, Oid relid,
+										  Datum *attrs, bool *attrs_isnull,
+										  Datum *qual, bool *qual_isnull);
+extern List *GetEffectiveRelationPublications(Oid relid, List *publications,
+											  Oid *publish_as_relid_p,
+											  PublicationActions *pubactions);
 
 /*---------
  * Expected values for pub_partopt parameter of GetRelationPublications(),
diff --git a/src/include/commands/copy.h b/src/include/commands/copy.h
index 33175868f6..774b835251 100644
--- a/src/include/commands/copy.h
+++ b/src/include/commands/copy.h
@@ -92,6 +92,8 @@ extern DestReceiver *CreateCopyDestReceiver(void);
 /*
  * internal prototypes
  */
+extern RawStmt *CreateCopyToQuery(const CopyStmt *stmt, Relation rel,
+								  int stmt_location, int stmt_len);
 extern CopyToState BeginCopyTo(ParseState *pstate, Relation rel, RawStmt *raw_query,
 							   Oid queryRelId, const char *filename, bool is_program,
 							   copy_data_dest_cb data_dest_cb, List *attnamelist, List *options);
-- 
2.31.1

