Skip to content

[ES|QL] Add a standard deviation function #116531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Nov 22, 2024

Conversation

limotova
Copy link
Contributor

@limotova limotova commented Nov 8, 2024

Uses Welford's online algorithm, as well as the parallel version, to
calculate standard deviation.

Uses Welford's online algorithm, as well as the parallel version, to
calculate standard deviation.
Copy link
Contributor

github-actions bot commented Nov 8, 2024

Documentation preview:

final long count = state.count();
final double m2 = state.m2();
if (count == 0 || Double.isFinite(m2) == false) {
return driverContext.blockFactory().newConstantNullBlock(1);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the result is infinity or NaN I set it to return null, but I'm not sure if there should be a warning or something similar printed (or where that would best be done)?


import static java.util.Collections.emptyList;

public class StdDeviation extends AggregateFunction implements ToAggregator {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure what best to name this. There were a few options: StdDeviation, StandardDeviation, or Stdev (or maybe even Stddev)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think StandardDeviation is better, although it's an internal name, and we can change it anytime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about for the name of the ES|QL function? Right now it's std_deviation, I feel like std_dev maybe works better? I worry that standard_deviation is kind of long

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think std_dev is better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the internal name be changed to StdDev to match?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer StandardDeviation for the class name, but feel free to choose whichever name you prefer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be easier to change both; I tried changing only the function name but it looks like the name of the tests class is used for generating the docs so I think it might be simpler to use StdDev for the class name as well

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can annotate the test class with the function name if it does not match cleanly. For example, see this class SpatialIntersectsTest which tests ST_INTERSECTS: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersectsTests.java#L23

@limotova limotova requested a review from dnhatn November 8, 2024 23:20
@dnhatn
Copy link
Member

dnhatn commented Nov 14, 2024

@limotova I extracted values from the serverless test and combined them to reproduce the test failure. Using three batches, the final value is 0.23282704603226836, while with a single batch, the value is 0.22797190865484734. Could you check if this discrepancy is acceptable?

    public void testBasic() {
        double[] v1 = {1.97, 2.0, 1.57, 1.48, 1.77};
        double[] v2 = {2.1, 1.74, 1.96, 1.42, 1.59, 2.07, 1.81, 1.59, 1.44, 2.03, 1.81};
        double[] v3 = {2.03, 1.54, 1.55};

        WelfordAlgorithm a1 = new WelfordAlgorithm();
        for (double v : v1) {
            a1.add(v);
        }

        WelfordAlgorithm a2 = new WelfordAlgorithm();
        for (double v : v2) {
            a2.add(v);
        }

        WelfordAlgorithm a3 = new WelfordAlgorithm();
        for (double v : v3) {
            a3.add(v);
        }

        WelfordAlgorithm merged = new WelfordAlgorithm();
        for (WelfordAlgorithm a : List.of(a3, a2, a1)) {
            merged.add(a.mean(), a.m2(), a.count());
        }

        System.err.println("--> merged = " + merged.evaluate());

        WelfordAlgorithm single = new WelfordAlgorithm();
        for (double v : v1) {
            single.add(v);
        }
        for (double v : v2) {
            single.add(v);
        }
        for (double v : v3) {
            single.add(v);
        }
        System.err.println("--> single = " + single.evaluate());
    }


import static java.util.Collections.emptyList;

public class StdDeviation extends AggregateFunction implements ToAggregator {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think StandardDeviation is better, although it's an internal name, and we can change it anytime.

}

public double evaluate() {
return count < 2 ? 0 : Math.sqrt(m2 / count);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be count-1 instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with count because I believe we use the population standard deviation elsewhere but I can change it if we'd prefer sample standard deviation?

* <a href="https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm">
* Parallel algorithm</a>
*/
public final class WelfordAlgorithm {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we fold this class into the StdDeviationStates#SingleState?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure if we wanted to use it elsewhere (like if we wanted to support variance or have both sample and population standard deviation)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can leave it as is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like the kind of thing that could move moved to a common place like libs/, but while there is only one usage, it might as well stay here for now.

Copy link
Member

@dnhatn dnhatn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have two optional comments, but overall, I think the PR is ready. LGTM! However, I’d love to have another review from the ES|QL team. Great work, thanks Larisa!

Copy link
Contributor

@ivancea ivancea left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Added some suggestions around tests mostly

@limotova limotova requested a review from ivancea November 21, 2024 01:01
@astefan
Copy link
Contributor

astefan commented Nov 21, 2024

Flyby feedback:

  • std_dev(first_name) fails with an unfriendly error message: org.elasticsearch.xpack.esql.EsqlIllegalArgumentException: Cannot find intermediate state for: AggDef[aggClazz=class org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev, type=BytesRef, extra=, grouping=false]. This probably comes from the lack of resolveType() method implementation.
  • std_dev(salary_change) I think deserves a test, meaning the aggregation function applied on a multi-value field. You are using salary_change in a test, but you apply mv_max on it, reducing it to a single value field.
  • another recent functionality we introduced for stats is the filter specific to an aggregation function. stats std_dev(salary) where languages > 3. It would be good to have a test for this as well. For example, FROM employees | stats std_dev(salary_change + 1) where languages > 3, std_dev(salary_change + 1) where languages <= 3, count(*) by gender
  • the row command is a different type of functionality where the source of the data is "static" and it doesn't come from ES. Would be good to have some tests with row as well. For example row a = [1,2,3], b = 5 | stats std_dev(a), std_dev(b) by a

Copy link
Contributor

@ivancea ivancea left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After adding Andrei's suggestions, LGTM!

@dnhatn dnhatn added v8.18.0 auto-backport Automatically create backport pull requests when merged labels Nov 21, 2024
Copy link
Contributor

@craigtaverner craigtaverner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need at least the resolveType method, and probably the tests that Andrei suggests.

@@ -0,0 +1,5 @@
pr: 116531
summary: "[ES|QL] Add a standard deviation function"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the prefix [ES|QL]. The changelog is organized by area, which already says it is ES|QL.

* <a href="https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm">
* Parallel algorithm</a>
*/
public final class WelfordAlgorithm {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like the kind of thing that could move moved to a common place like libs/, but while there is only one usage, it might as well stay here for now.

tag = "docsStatsStdDevNestedExpression"
) }
)
public StdDev(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The params claims to only support the numeric types, but there is no resolveType function to enforce this. It should be similar to the Avg function: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java#L57

@limotova
Copy link
Contributor Author

limotova commented Nov 22, 2024

@astefan

another recent functionality we introduced for stats is the filter specific to an aggregation function. stats std_dev(salary) where languages > 3. It would be good to have a test for this as well. For example, FROM employees | stats std_dev(salary_change + 1) where languages > 3, std_dev(salary_change + 1) where languages <= 3, count(*) by gender

I tried to add in tests with salary_change + 1, but since some of the values are empty it looks like it doesn't handle the + 1 very well (it seems like it ends the calculation as soon as it hits an empty value), is this expected? It works as expected with just salary_change (it skips empty values and calculates everything else), and it also handles something like salary * 2 (not multivalue, but no empty values) fine.
(Actually based on the test failure it might not be working properly... I'm unable to reproduce it locally though and am a bit stumped why it's happening)

the row command is a different type of functionality where the source of the data is "static" and it doesn't come from ES. Would be good to have some tests with row as well. For example row a = [1,2,3], b = 5 | stats std_dev(a), std_dev(b) by a

It seems to work with row as far as I can tell, but I'm not sure I understand what I should be looking for when I add by a here? The standard deviation of only one unique value is always 0, but since a is a multi-value here, the result of std_dev(a) by a is the same as std_dev(a), just the former repeats the result 3 times for each value of a

Copy link
Contributor

@craigtaverner craigtaverner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

*/
@Aggregator(
{
@IntermediateState(name = "mean", type = "DOUBLE"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to specify that these are blocks or we need to always emit them as vectors (with count=0)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I changed it to returning 0 in the intermediate stages

@limotova limotova merged commit 7e801e0 into elastic:main Nov 22, 2024
16 checks passed
@limotova limotova deleted the add-stddev-function branch November 22, 2024 22:33
@elasticsearchmachine
Copy link
Collaborator

💚 Backport successful

Status Branch Result
8.x

limotova added a commit to limotova/elasticsearch that referenced this pull request Nov 22, 2024
Uses Welford's online algorithm, as well as the parallel version, to
calculate standard deviation.
elasticsearchmachine pushed a commit that referenced this pull request Nov 22, 2024
Uses Welford's online algorithm, as well as the parallel version, to
calculate standard deviation.
alexey-ivanov-es pushed a commit to alexey-ivanov-es/elasticsearch that referenced this pull request Nov 28, 2024
Uses Welford's online algorithm, as well as the parallel version, to
calculate standard deviation.
@alex-spies alex-spies mentioned this pull request Dec 2, 2024
99 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
:Analytics/ES|QL AKA ESQL auto-backport Automatically create backport pull requests when merged >enhancement Team:Analytics Meta label for analytical engine team (ESQL/Aggs/Geo) v8.18.0 v9.0.0
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants