Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion docs/spark-procedures.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ This procedure invalidates all cached Spark plans that reference the affected ta
| Argument Name | Required? | Type | Description |
|---------------|-----------|------|-------------|
| `table` | ✔️ | string | Name of the table to update |
| `snapshot_id` | ✔️ | long | Snapshot ID to set as current |
| `snapshot_id` | | long | Snapshot ID to set as current |
| `ref` | | string | Snapshot Reference (branch or tag) to set as current |

Either `snapshot_id` or `ref` must be provided but not both.

#### Output

Expand All @@ -146,6 +149,11 @@ Set the current snapshot for `db.sample` to 1:
CALL catalog_name.system.set_current_snapshot('db.sample', 1)
```

Set the current snapshot for `db.sample` to tag `s1`:
```sql
CALL catalog_name.system.set_current_snapshot(table => 'db.sample', ref => 's1');
```

### `cherrypick_snapshot`

Cherry-picks changes from a snapshot into the current table state.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,12 @@ public void testInvalidRollbackToSnapshotCases() {

Assertions.assertThatThrownBy(
() -> sql("CALL %s.system.set_current_snapshot('t')", catalogName))
.isInstanceOf(AnalysisException.class)
.hasMessage("Missing required parameters: [snapshot_id]");
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Either snapshot_id or ref must be provided, not both");

Assertions.assertThatThrownBy(() -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName))
.isInstanceOf(AnalysisException.class)
.hasMessage("Missing required parameters: [snapshot_id]");
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Cannot parse identifier for arg table: 1");

Assertions.assertThatThrownBy(
() -> sql("CALL %s.system.set_current_snapshot(snapshot_id => 1L)", catalogName))
Expand All @@ -226,8 +226,8 @@ public void testInvalidRollbackToSnapshotCases() {

Assertions.assertThatThrownBy(
() -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName))
.isInstanceOf(AnalysisException.class)
.hasMessage("Missing required parameters: [snapshot_id]");
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Either snapshot_id or ref must be provided, not both");

Assertions.assertThatThrownBy(
() -> sql("CALL %s.system.set_current_snapshot('t', 2.2)", catalogName))
Expand All @@ -238,5 +238,58 @@ public void testInvalidRollbackToSnapshotCases() {
() -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Cannot handle an empty identifier for argument table");

Assertions.assertThatThrownBy(
() ->
sql(
"CALL %s.system.set_current_snapshot(table => 't', snapshot_id => 1L, ref => 's1')",
catalogName))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Either snapshot_id or ref must be provided, not both");
}

/**
 * Verifies that {@code set_current_snapshot} accepts a named ref: after tagging the first
 * snapshot and writing a second one, calling the procedure with the tag moves the table back to
 * the tagged snapshot, and an unknown ref name is rejected with a {@code ValidationException}.
 */
@Test
public void testSetCurrentSnapshotToRef() {
  // First write: this is the snapshot the tag will point at.
  sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);

  Table icebergTable = validationCatalog.loadTable(tableIdent);
  Snapshot taggedSnapshot = icebergTable.currentSnapshot();
  String tagName = "s1";
  sql("ALTER TABLE %s CREATE TAG %s", tableName, tagName);

  // Second write: moves the table's current snapshot past the tagged one.
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
  List<Object[]> rowsAfterSecondInsert = sql("SELECT * FROM %s ORDER BY id", tableName);
  assertEquals(
      "Should have expected rows",
      ImmutableList.of(row(1L, "a"), row(1L, "a")),
      rowsAfterSecondInsert);

  icebergTable.refresh();
  Snapshot latestSnapshot = icebergTable.currentSnapshot();

  // The procedure reports (previous snapshot id, newly set snapshot id).
  List<Object[]> procedureResult =
      sql(
          "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')",
          catalogName, tableIdent, tagName);
  Object[] expectedResultRow = row(latestSnapshot.snapshotId(), taggedSnapshot.snapshotId());
  assertEquals("Procedure output must match", ImmutableList.of(expectedResultRow), procedureResult);

  // Only the row from the first insert should remain visible.
  List<Object[]> rowsAfterSet = sql("SELECT * FROM %s ORDER BY id", tableName);
  assertEquals("Set must be successful", ImmutableList.of(row(1L, "a")), rowsAfterSet);

  // A ref name that was never created must be rejected.
  String missingRef = "s2";
  Assertions.assertThatThrownBy(
          () ->
              sql(
                  "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')",
                  catalogName, tableIdent, missingRef))
      .isInstanceOf(ValidationException.class)
      .hasMessage("Cannot find matching snapshot ID for ref " + missingRef);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
package org.apache.iceberg.spark.procedures;

import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotRef;
import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.Identifier;
Expand All @@ -42,7 +46,8 @@ class SetCurrentSnapshotProcedure extends BaseProcedure {
// Input parameters in positional order; call() reads them by index
// (0 = table, 1 = snapshot_id, 2 = ref).
// NOTE(review): this view lists snapshot_id twice (required, then optional) — these look like
// the before/after lines of a change; confirm which declaration the file actually keeps.
private static final ProcedureParameter[] PARAMETERS =
new ProcedureParameter[] {
ProcedureParameter.required("table", DataTypes.StringType),
ProcedureParameter.required("snapshot_id", DataTypes.LongType)
ProcedureParameter.optional("snapshot_id", DataTypes.LongType),
ProcedureParameter.optional("ref", DataTypes.StringType)
};

private static final StructType OUTPUT_TYPE =
Expand Down Expand Up @@ -78,17 +83,22 @@ public StructType outputType() {
/**
 * Sets the table's current snapshot and reports the previous and the newly set snapshot IDs.
 *
 * <p>Argument 0 is the table identifier, argument 1 is {@code snapshot_id} and argument 2 is
 * {@code ref}. In the nullable variant below, exactly one of {@code snapshot_id} and {@code ref}
 * must be non-null; a ref is resolved to a snapshot ID through {@code toSnapshotId}.
 *
 * <p>NOTE(review): this view shows several statements twice (a primitive and a nullable read of
 * snapshot_id, two setCurrentSnapshot commits, two outputRow declarations). They look like the
 * before/after lines of a change — confirm which ones the file actually contains. This
 * description follows the nullable, ref-aware variants.
 *
 * @param args positional procedure arguments
 * @return a single row holding the previous snapshot ID (null when the table had no current
 *     snapshot) and the snapshot ID that was set
 * @throws IllegalArgumentException if both or neither of snapshot_id and ref are supplied
 */
@Override
public InternalRow[] call(InternalRow args) {
Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name());
long snapshotId = args.getLong(1);
Long snapshotId = args.isNullAt(1) ? null : args.getLong(1);
String ref = args.isNullAt(2) ? null : args.getString(2);
// Exactly one of the two selectors may be set; the same message covers "both" and "neither".
Preconditions.checkArgument(
(snapshotId != null && ref == null) || (snapshotId == null && ref != null),
"Either snapshot_id or ref must be provided, not both");

return modifyIcebergTable(
tableIdent,
table -> {
// Capture the current snapshot before changing it so it can be reported in the output row.
Snapshot previousSnapshot = table.currentSnapshot();
Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null;

table.manageSnapshots().setCurrentSnapshot(snapshotId).commit();
// Prefer the explicit snapshot ID; otherwise resolve the ref against this table.
long targetSnapshotId = snapshotId != null ? snapshotId : toSnapshotId(table, ref);
table.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit();

InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId);
InternalRow outputRow = newInternalRow(previousSnapshotId, targetSnapshotId);
return new InternalRow[] {outputRow};
});
}
Expand All @@ -97,4 +107,10 @@ public InternalRow[] call(InternalRow args) {
/** Returns the fixed description string for this procedure. */
public String description() {
return "SetCurrentSnapshotProcedure";
}

/**
 * Resolves a named snapshot reference (branch or tag) to the snapshot ID it points at.
 *
 * @param table table whose refs are searched
 * @param refName name of the reference to look up
 * @return the snapshot ID recorded for {@code refName}
 * @throws ValidationException if the table has no reference named {@code refName}
 */
private long toSnapshotId(Table table, String refName) {
  SnapshotRef ref = table.refs().get(refName);
  // Pass the name as a format argument instead of concatenating it: check() treats its message
  // as a format string, so a '%' inside refName would otherwise be parsed as a specifier.
  ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref %s", refName);
  return ref.snapshotId();
}
}