diff --git a/pg_wait_sampling.c b/pg_wait_sampling.c index a35fb94..e165a6a 100644 --- a/pg_wait_sampling.c +++ b/pg_wait_sampling.c @@ -76,13 +76,7 @@ static PlannedStmt *pgws_planner_hook(Query *parse, const char *query_string, #endif int cursorOptions, ParamListInfo boundParams); -static -#if PG_VERSION_NUM >= 180000 -bool -#else -void -#endif -pgws_ExecutorStart(QueryDesc *queryDesc, int eflags); +static void pgws_ExecutorStart(QueryDesc *queryDesc, int eflags); static void pgws_ExecutorRun(QueryDesc *queryDesc, ScanDirection direction, uint64 count @@ -655,6 +649,10 @@ receive_array(SHMRequest request, Size item_size, Size *count) pgws_collector_hdr->request = request; LockRelease(&collectorTag, ExclusiveLock, false); + /* + * Check that the collector was started to avoid NULL + * pointer dereference. + */ if (!pgws_collector_hdr->latch) ereport(ERROR, (errcode(ERRCODE_INTERNAL_ERROR), errmsg("pg_wait_sampling collector wasn't started"))); @@ -825,6 +823,14 @@ pg_wait_sampling_reset_profile(PG_FUNCTION_ARGS) pgws_collector_hdr->request = PROFILE_RESET; LockRelease(&collectorTag, ExclusiveLock, false); + /* + * Check that the collector was started to avoid NULL + * pointer dereference. + */ + if (!pgws_collector_hdr->latch) + ereport(ERROR, (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("pg_wait_sampling collector wasn't started"))); + SetLatch(pgws_collector_hdr->latch); LockRelease(&queueTag, ExclusiveLock, false); @@ -982,12 +988,7 @@ pgws_planner_hook(Query *parse, /* * ExecutorStart hook: save queryId for collector */ -static -#if PG_VERSION_NUM >= 180000 -bool -#else -void -#endif +static void pgws_ExecutorStart(QueryDesc *queryDesc, int eflags) { int i = MyProc - ProcGlobal->allProcs; @@ -995,9 +996,9 @@ pgws_ExecutorStart(QueryDesc *queryDesc, int eflags) if (pgws_enabled(nesting_level)) pgws_proc_queryids[i] = queryDesc->plannedstmt->queryId; if (prev_ExecutorStart) - return prev_ExecutorStart(queryDesc, eflags); + prev_ExecutorStart(queryDesc, eflags); else - return standard_ExecutorStart(queryDesc, eflags); + standard_ExecutorStart(queryDesc, eflags); } static void