diff --git a/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java b/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java
new file mode 100644
index 000000000000..17e489746f94
--- /dev/null
+++ b/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.source.ThreeColumnRecord;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
+import org.junit.After;
+import org.junit.Test;
+
+public class TestRewriteDataFilesProcedure extends SparkExtensionsTestBase {
+
+  public TestRewriteDataFilesProcedure(String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @After
+  public void removeTable() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @Test
+  public void testRewriteDataFilesInEmptyTable() {
+    createTable();
+    List<Object[]> output = sql(
+        "CALL %s.system.rewrite_data_files('%s')", catalogName, tableIdent);
+    assertEquals("Procedure output must match",
+        ImmutableList.of(row(0, 0)),
+        output);
+  }
+
+  @Test
+  public void testRewriteDataFilesOnPartitionTable() {
+    createPartitionTable();
+    // create 5 files for each partition (c2 = 'foo' and c2 = 'bar')
+    insertData(10);
+    List<Object[]> expectedRecords = currentData();
+
+    List<Object[]> output = sql(
+        "CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);
+
+    assertEquals("Action should rewrite 10 data files and add 2 data files (one per partition)",
+        ImmutableList.of(row(10, 2)),
+        output);
+
+    List<Object[]> actualRecords = currentData();
+    assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
+  }
+
+  @Test
+  public void testRewriteDataFilesOnNonPartitionTable() {
+    createTable();
+    // create 10 files under non-partitioned table
+    insertData(10);
+    List<Object[]> expectedRecords = currentData();
+
+    List<Object[]> output = sql(
+        "CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);
+
+    assertEquals("Action should rewrite 10 data files and add 1 data file",
+        ImmutableList.of(row(10, 1)),
+        output);
+
+    List<Object[]> actualRecords = currentData();
+    assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
+  }
+
+  @Test
+  public void testRewriteDataFilesWithOptions() {
+    createTable();
+    // create 10 files under non-partitioned table
+    insertData(10);
+    List<Object[]> expectedRecords = currentData();
+
+    // set min-input-files to 12 (instead of the default 5) so that compaction is skipped
+    List<Object[]> output = sql(
+        "CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','12'))",
+        catalogName, tableIdent);
+
+    assertEquals("Action should rewrite 0 data files and add 0 data files",
+        ImmutableList.of(row(0, 0)),
+        output);
+
+    List<Object[]> actualRecords = currentData();
+    assertEquals("Data should not change", expectedRecords, actualRecords);
+  }
+
+  @Test
+  public void testRewriteDataFilesWithSortStrategy() {
+    createTable();
+    // create 10 files under non-partitioned table
+    insertData(10);
+    List<Object[]> expectedRecords = currentData();
+
+    // set sort_order = c1 DESC NULLS LAST
+    List<Object[]> output = sql(
+        "CALL %s.system.rewrite_data_files(table => '%s', " +
+            "strategy => 'sort', sort_order => 'c1 DESC NULLS LAST')",
+        catalogName, tableIdent);
+
+    assertEquals("Action should rewrite 10 data files and add 1 data file",
+        ImmutableList.of(row(10, 1)),
+        output);
+
+    List<Object[]> actualRecords = currentData();
+    assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
+  }
+
+  @Test
+  public void testRewriteDataFilesWithFilter() {
+    createTable();
+    // create 10 files under non-partitioned table
+    insertData(10);
+    List<Object[]> expectedRecords = currentData();
+
+    // select only 5 files for compaction (files that may have c1 = 1)
+    List<Object[]> output = sql(
+        "CALL %s.system.rewrite_data_files(table => '%s'," +
+            " where => 'c1 = 1 and c2 is not null')", catalogName, tableIdent);
+
+    assertEquals("Action should rewrite 5 data files (containing c1 = 1) and add 1 data file",
+        ImmutableList.of(row(5, 1)),
+        output);
+
+    List<Object[]> actualRecords = currentData();
+    assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
+  }
+
+  @Test
+  public void testRewriteDataFilesWithFilterOnPartitionTable() {
+    createPartitionTable();
+    // create 5 files for each partition (c2 = 'foo' and c2 = 'bar')
+    insertData(10);
+    List<Object[]> expectedRecords = currentData();
+
+    // select only 5 files for compaction (files in the partition c2 = 'bar')
+    List<Object[]> output = sql(
+        "CALL %s.system.rewrite_data_files(table => '%s'," +
+            " where => 'c2 = \"bar\"')", catalogName, tableIdent);
+
+    assertEquals("Action should rewrite 5 data files from the single matching partition " +
+            "(containing c2 = bar) and add 1 data file",
+        ImmutableList.of(row(5, 1)),
+        output);
+
+    List<Object[]> actualRecords = currentData();
+    assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
+  }
+
+  @Test
+  public void testRewriteDataFilesWithInvalidInputs() {
+    createTable();
+    // create 2 files under non-partitioned table
+    insertData(2);
+
+    // Test for invalid strategy
+    AssertHelpers.assertThrows("Should reject calls with unsupported strategy error message",
+        IllegalArgumentException.class, "unsupported strategy: temp. Only binpack,sort is supported",
+        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','2'), " +
+            "strategy => 'temp')", catalogName, tableIdent));
+
+    // Test for sort_order with binpack strategy
+    AssertHelpers.assertThrows("Should reject calls with error message",
+        IllegalArgumentException.class, "Cannot set strategy to sort, it has already been set",
+        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'binpack', " +
+            "sort_order => 'c1 ASC NULLS FIRST')", catalogName, tableIdent));
+
+    // Test for sort_order with invalid null order
+    AssertHelpers.assertThrows("Should reject calls with error message",
+        IllegalArgumentException.class, "Unable to parse sortOrder:",
+        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
+            "sort_order => 'c1 ASC none')", catalogName, tableIdent));
+
+    // Test for sort_order with invalid sort direction
+    AssertHelpers.assertThrows("Should reject calls with error message",
+        IllegalArgumentException.class, "Unable to parse sortOrder:",
+        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
+            "sort_order => 'c1 none NULLS FIRST')", catalogName, tableIdent));
+
+    // Test for sort_order with invalid column name
+    AssertHelpers.assertThrows("Should reject calls with error message",
+        ValidationException.class, "Cannot find field 'col1' in struct:" +
+            " struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>",
+        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
+            "sort_order => 'col1 DESC NULLS FIRST')", catalogName, tableIdent));
+
+    // Test for where filter with invalid column name col1
+    AssertHelpers.assertThrows("Should reject calls with error message",
+        ValidationException.class, "Cannot find field 'col1' in struct:" +
+            " struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>",
+        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', " +
+            "where => 'col1 = 3')", catalogName, tableIdent));
+  }
+
+  @Test
+  public void testInvalidCasesForRewriteDataFiles() {
+    AssertHelpers.assertThrows("Should not allow mixed args",
+        AnalysisException.class, "Named and positional arguments cannot be mixed",
+        () -> sql("CALL %s.system.rewrite_data_files('n', table => 't')", catalogName));
+
+    AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces",
+        NoSuchProcedureException.class, "not found",
+        () -> sql("CALL %s.custom.rewrite_data_files('n', 't')", catalogName));
+
+    AssertHelpers.assertThrows("Should reject calls without all required args",
+        AnalysisException.class, "Missing required parameters",
+        () -> sql("CALL %s.system.rewrite_data_files()", catalogName));
+
+    AssertHelpers.assertThrows("Should reject duplicate arg names",
+        AnalysisException.class, "Duplicate procedure argument: table",
+        () -> sql("CALL %s.system.rewrite_data_files(table => 't', table => 't')", catalogName));
+
+    AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+        IllegalArgumentException.class, "Cannot handle an empty identifier",
+        () -> sql("CALL %s.system.rewrite_data_files('')", catalogName));
+  }
+
+  private void createTable() {
+    sql("CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", tableName);
+  }
+
+  private void createPartitionTable() {
+    sql("CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg PARTITIONED BY (c2)", tableName);
+  }
+
+  private void insertData(int filesCount) {
+    ThreeColumnRecord record1 = new ThreeColumnRecord(1, "foo", "detail1");
+    ThreeColumnRecord record2 = new ThreeColumnRecord(2, "bar", "detail2");
+
+    List<ThreeColumnRecord> records = Lists.newArrayList();
+    IntStream.range(0, filesCount / 2).forEach(i -> {
+      records.add(record1);
+      records.add(record2);
+    });
+
+    Dataset<Row> df = spark.createDataFrame(records, ThreeColumnRecord.class).repartition(filesCount);
+    try {
+      df.writeTo(tableName).append();
+    } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private List<Object[]> currentData() {
+    return rowsToJava(spark.sql("SELECT * FROM " + tableName + " order by c1, c2, c3").collectAsList());
+  }
+}
diff --git a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java
new file mode 100644
index 000000000000..76550f93a33c
--- /dev/null
+++ b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.procedures;
+
+import java.util.Map;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.actions.RewriteDataFiles;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.parser.ParseException;
+import org.apache.spark.sql.catalyst.parser.ParserInterface;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering;
+import org.apache.spark.sql.catalyst.plans.logical.SortOrderParserUtil;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter;
+import org.apache.spark.sql.execution.datasources.SparkExpressionConverter;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import scala.runtime.BoxedUnit;
+
+/**
+ * A procedure that rewrites data files in a table.
+ *
+ * @see org.apache.iceberg.spark.actions.SparkActions#rewriteDataFiles(Table)
+ */
+class RewriteDataFilesProcedure extends BaseProcedure {
+
+  private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{
+      ProcedureParameter.required("table", DataTypes.StringType),
+      ProcedureParameter.optional("strategy", DataTypes.StringType),
+      ProcedureParameter.optional("sort_order", DataTypes.StringType),
+      ProcedureParameter.optional("options", STRING_MAP),
+      ProcedureParameter.optional("where", DataTypes.StringType)
+  };
+
+  // counts are not nullable since the action result is never null
+  private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{
+      new StructField("rewritten_data_files_count", DataTypes.IntegerType, false, Metadata.empty()),
+      new StructField("added_data_files_count", DataTypes.IntegerType, false, Metadata.empty())
+  });
+
+  public static ProcedureBuilder builder() {
+    return new Builder<RewriteDataFilesProcedure>() {
+      @Override
+      protected RewriteDataFilesProcedure doBuild() {
+        return new RewriteDataFilesProcedure(tableCatalog());
+      }
+    };
+  }
+
+  private RewriteDataFilesProcedure(TableCatalog tableCatalog) {
+    super(tableCatalog);
+  }
+
+  @Override
+  public ProcedureParameter[] parameters() {
+    return PARAMETERS;
+  }
+
+  @Override
+  public StructType outputType() {
+    return OUTPUT_TYPE;
+  }
+
+  @Override
+  public InternalRow[] call(InternalRow args) {
+    Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name());
+
+    return modifyIcebergTable(tableIdent, table -> {
+      RewriteDataFiles action = actions().rewriteDataFiles(table);
+
+      String strategy = args.isNullAt(1) ? null : args.getString(1);
+      String sortOrderString = args.isNullAt(2) ? null : args.getString(2);
+      SortOrder sortOrder = null;
+      if (sortOrderString != null) {
+        sortOrder = collectSortOrders(table, sortOrderString);
+      }
+      if (strategy != null || sortOrder != null) {
+        action = checkAndApplyStrategy(action, strategy, sortOrder);
+      }
+
+      if (!args.isNullAt(3)) {
+        action = checkAndApplyOptions(args, action);
+      }
+
+      String where = args.isNullAt(4) ? null : args.getString(4);
+      action = checkAndApplyFilter(action, where);
+
+      RewriteDataFiles.Result result = action.execute();
+
+      return toOutputRows(result);
+    });
+  }
+
+  private RewriteDataFiles checkAndApplyFilter(RewriteDataFiles action, String where) {
+    if (where != null) {
+      ParserInterface sqlParser = spark().sessionState().sqlParser();
+      try {
+        Expression expression = sqlParser.parseExpression(where);
+        return action.filter(SparkExpressionConverter.convertToIcebergExpression(expression));
+      } catch (ParseException e) {
+        throw new IllegalArgumentException("Cannot parse predicates in where option: " + where);
+      }
+    }
+    return action;
+  }
+
+  private RewriteDataFiles checkAndApplyOptions(InternalRow args, RewriteDataFiles action) {
+    Map<String, String> options = Maps.newHashMap();
+    args.getMap(3).foreach(DataTypes.StringType, DataTypes.StringType,
+        (k, v) -> {
+          options.put(k.toString(), v.toString());
+          return BoxedUnit.UNIT;
+        });
+    return action.options(options);
+  }
+
+  private RewriteDataFiles checkAndApplyStrategy(RewriteDataFiles action, String strategy, SortOrder sortOrder) {
+    // The caller ensures that at least one of strategy and sortOrder is non-null.
+    if (strategy == null || strategy.equalsIgnoreCase("sort")) {
+      return action.sort(sortOrder);
+    }
+    if (strategy.equalsIgnoreCase("binpack")) {
+      RewriteDataFiles rewriteDataFiles = action.binPack();
+      if (sortOrder != null) {
+        // call sort() to surface the error, since the user set both the binpack strategy and a sort order
+        return rewriteDataFiles.sort(sortOrder);
+      }
+      return rewriteDataFiles;
+    } else {
+      throw new IllegalArgumentException("unsupported strategy: " + strategy + ". Only binpack,sort is supported");
+    }
+  }
+
+  private SortOrder collectSortOrders(Table table, String sortOrderStr) {
+    String prefix = "ALTER TABLE temp WRITE ORDERED BY ";
+    try {
+      // Note: Reuse the existing Iceberg SQL parser to avoid implementing a custom parser for sort orders.
+      // To reuse it, prepend "ALTER TABLE temp WRITE ORDERED BY" to the input sort order, parse the result
+      // as a plan, and collect the sort order from that plan.
+      LogicalPlan logicalPlan = spark().sessionState().sqlParser().parsePlan(prefix + sortOrderStr);
+      return (new SortOrderParserUtil()).collectSortOrder(
+          table.schema(),
+          ((SetWriteDistributionAndOrdering) logicalPlan).sortOrder());
+    } catch (AnalysisException ex) {
+      throw new IllegalArgumentException("Unable to parse sortOrder: " + sortOrderStr);
+    }
+  }
+
+  private InternalRow[] toOutputRows(RewriteDataFiles.Result result) {
+    int rewrittenDataFilesCount = result.rewrittenDataFilesCount();
+    int addedDataFilesCount = result.addedDataFilesCount();
+    InternalRow row = newInternalRow(rewrittenDataFilesCount, addedDataFilesCount);
+    return new InternalRow[]{row};
+  }
+
+  @Override
+  public String description() {
+    return "RewriteDataFilesProcedure";
+  }
+}
diff --git a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java
index 42545abe11d2..4ce9460b90ce 100644
--- a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java
+++ b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java
@@ -45,6 +45,7 @@ private static Map<String, Supplier<ProcedureBuilder>> initProcedureBuilders() {
     mapBuilder.put("rollback_to_timestamp", RollbackToTimestampProcedure::builder);
     mapBuilder.put("set_current_snapshot", SetCurrentSnapshotProcedure::builder);
     mapBuilder.put("cherrypick_snapshot", CherrypickSnapshotProcedure::builder);
+    mapBuilder.put("rewrite_data_files", RewriteDataFilesProcedure::builder);
     mapBuilder.put("rewrite_manifests", RewriteManifestsProcedure::builder);
     mapBuilder.put("remove_orphan_files", RemoveOrphanFilesProcedure::builder);
     mapBuilder.put("expire_snapshots", ExpireSnapshotsProcedure::builder);
diff --git a/spark/v3.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala b/spark/v3.1/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala
similarity index 100%
rename from spark/v3.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala
rename to spark/v3.1/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala
diff --git a/spark/v3.1/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala b/spark/v3.1/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala
new file mode 100644
index 000000000000..bf19ef8a2167
--- /dev/null
+++ b/spark/v3.1/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.iceberg.NullOrder
+import org.apache.iceberg.Schema
+import org.apache.iceberg.SortDirection
+import org.apache.iceberg.SortOrder
+import org.apache.iceberg.expressions.Term
+
+class SortOrderParserUtil {
+
+  def collectSortOrder(tableSchema: Schema, sortOrder: Seq[(Term, SortDirection, NullOrder)]): SortOrder = {
+    val orderBuilder = SortOrder.builderFor(tableSchema)
+    sortOrder.foreach {
+      case (term, SortDirection.ASC, nullOrder) =>
+        orderBuilder.asc(term, nullOrder)
+      case (term, SortDirection.DESC, nullOrder) =>
+        orderBuilder.desc(term, nullOrder)
+    }
+    orderBuilder.build()
+  }
+}
diff --git a/spark/v3.1/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala b/spark/v3.1/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala
new file mode 100644
index 000000000000..c41852713d1a
--- /dev/null
+++ b/spark/v3.1/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.iceberg.spark.SparkFilters
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+object SparkExpressionConverter {
+
+  def convertToIcebergExpression(sparkExpression: Expression): org.apache.iceberg.expressions.Expression = {
+    // This is currently a double conversion: the Spark expression is converted to a Spark filter,
+    // and the Spark filter is then converted to an Iceberg expression.
+    // Both conversions already exist and are well tested, so we take this approach.
+    SparkFilters.convert(DataSourceStrategy.translateFilter(sparkExpression, supportNestedPredicatePushdown = true).get)
+  }
+}
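
For reference, a minimal sketch of invoking the new rewrite_data_files procedure from Spark SQL, mirroring the calls exercised in the tests above. The catalog name my_catalog and the table db.sample are hypothetical placeholders, and the session is assumed to have that Iceberg catalog configured.

// Usage sketch only; "my_catalog" and "db.sample" are hypothetical placeholders.
import org.apache.spark.sql.SparkSession;

public class RewriteDataFilesExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("rewrite-data-files-example")
        .getOrCreate();

    // Default bin-pack compaction over the whole table.
    spark.sql("CALL my_catalog.system.rewrite_data_files(table => 'db.sample')").show();

    // Sort-based rewrite with an explicit sort order, restricted by a row filter.
    spark.sql("CALL my_catalog.system.rewrite_data_files(table => 'db.sample', "
        + "strategy => 'sort', sort_order => 'c1 DESC NULLS LAST', where => 'c1 = 1')").show();

    // Pass rewrite options, e.g. raise min-input-files so that small file groups are skipped.
    spark.sql("CALL my_catalog.system.rewrite_data_files(table => 'db.sample', "
        + "options => map('min-input-files', '12'))").show();

    spark.stop();
  }
}

Each call returns a single row with rewritten_data_files_count and added_data_files_count, matching the OUTPUT_TYPE declared in the procedure.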