
Hyperparameter tuning with Ray Tune

Created On: Aug 31, 2020 | Last Updated: Oct 31, 2024 | Last Verified: Nov 05, 2024

Hyperparameter tuning can make the difference between an average model and a highly accurate one. Often, simple things like choosing a different learning rate or changing a network layer size can have a dramatic impact on your model’s performance.

Fortunately, there are tools that help with finding the best combination of parameters. Ray Tune is an industry-standard tool for distributed hyperparameter tuning. Ray Tune includes the latest hyperparameter search algorithms, integrates with various analysis libraries, and natively supports distributed training through Ray’s distributed machine learning engine.

In this tutorial, we will show you how to integrate Ray Tune into your PyTorch training workflow. We will extend the CIFAR10 image classifier tutorial from the PyTorch documentation.

As you will see, we only need to make a few small modifications. In particular, we need to

  1. wrap data loading and training in functions,

  2. make some network parameters configurable,

  3. add checkpointing (optional),

  4. and define the search space for the model tuning.


To run this tutorial, please make sure the following packages are installed:

  • ray[tune]: Distributed hyperparameter tuning library

  • torchvision: For the CIFAR10 dataset and the data transforms

Setup / Imports

Let’s start with the imports:

from functools import partial
import os
import tempfile
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray import train
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
import ray.cloudpickle as pickle

Most of the imports are needed for building the PyTorch model. Only the last few imports (those from ray) are for Ray Tune.

Data loaders

We wrap the dataset loading in its own function and pass a global data directory. This way, different trials can share the same data directory.

def load_data(data_dir="./data"):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )

    return trainset, testset

Configurable neural network

We can only tune those parameters that are configurable. In this example, we can specify the layer sizes of the fully connected layers:

class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
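The input size of fc1, 16 * 5 * 5, follows from the 32x32 CIFAR10 images. The sketch below traces the shapes through the same convolution and pooling layers used in Net; the ReLU activations are left out because they do not change shapes.

```python
import torch
import torch.nn as nn

# The same feature-extraction layers as in Net.
conv1 = nn.Conv2d(3, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

x = torch.zeros(1, 3, 32, 32)  # one CIFAR10-sized image: 3 channels, 32x32 pixels
x = pool(conv1(x))             # 32 - 5 + 1 = 28, pooled to 14 -> (1, 6, 14, 14)
x = pool(conv2(x))             # 14 - 5 + 1 = 10, pooled to 5  -> (1, 16, 5, 5)
flat = torch.flatten(x, 1)     # keep the batch dimension, flatten the rest

print(tuple(x.shape))  # (1, 16, 5, 5)
print(flat.shape[1])   # 400, i.e. 16 * 5 * 5, the number of inputs to fc1
```

Only l1 and l2 are made configurable because changing the convolution layers would also change this flattened size.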

The train function

Now it gets interesting, because we introduce some changes to the example from the PyTorch documentation.

We wrap the training script in a function train_cifar(config, data_dir=None). The config parameter will receive the hyperparameters we would like to train with. The data_dir specifies the directory where we load and store the data, so that multiple runs can share the same data source. We also load the model and optimizer state at the start of the run, if a checkpoint is provided. Further down in this tutorial you will find information on how to save the checkpoint and what it is used for.

net = Net(config["l1"], config["l2"])

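# `optimizer` must already exist here: in the full train_cifar function below,
# it is created right after the model and before the checkpoint is restored.
checkpoint = get_checkpoint()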
if checkpoint:
    with checkpoint.as_directory() as checkpoint_dir:
        data_path = Path(checkpoint_dir) / "data.pkl"
        with open(data_path, "rb") as fp:
            checkpoint_state = pickle.load(fp)
        start_epoch = checkpoint_state["epoch"]
        net.load_state_dict(checkpoint_state["net_state_dict"])
        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
else:
    start_epoch = 0

The learning rate of the optimizer is made configurable, too:

optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

We also split the training data into training and validation subsets: we train on 80% of the data and calculate the validation loss on the remaining 20%. The batch sizes with which we iterate through the training and validation sets are configurable as well.
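The 80/20 split can be tried on a stand-in dataset with the same length as the CIFAR10 training set (50,000 images). In train_cifar, the variable test_abs holds the size of the training part. The seeded generator below is an assumption and not part of the tutorial’s code; without it, each trial draws its own random split.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the CIFAR10 training set: only the length (50,000) matters here.
dataset = TensorDataset(torch.arange(50_000))

train_size = int(len(dataset) * 0.8)   # 40,000 samples for training
val_size = len(dataset) - train_size   # 10,000 samples for validation

# Optional: a seeded generator makes the split reproducible across runs.
generator = torch.Generator().manual_seed(42)
train_subset, val_subset = random_split(dataset, [train_size, val_size], generator=generator)

print(len(train_subset), len(val_subset))  # 40000 10000
```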

Adding (multi) GPU support with DataParallel

Image classification benefits greatly from GPUs. Luckily, we can continue to use PyTorch’s abstractions in Ray Tune. Thus, we can wrap our model in nn.DataParallel to support data parallel training on multiple GPUs:

device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
net.to(device)

By using a device variable, we make sure that training also works when no GPUs are available. PyTorch requires us to send our data to the GPU memory explicitly, like this:

for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)

The code now supports training on CPUs, on a single GPU, and on multiple GPUs. Notably, Ray also supports fractional GPUs, so we can share GPUs among trials, as long as the model still fits in GPU memory. We’ll come back to that later.
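One consequence of wrapping a model in nn.DataParallel matters for checkpoints: the wrapper keeps the original model in its module attribute, so every key in its state_dict gains a "module." prefix. A checkpoint must therefore be loaded into a model that is wrapped (or not) the same way as the one that saved it, which is why the main function later in this tutorial wraps best_trained_model before loading the state dict. A minimal check on a single layer (on a CPU-only machine, nn.DataParallel simply wraps the module):

```python
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

print(sorted(model.state_dict().keys()))    # ['bias', 'weight']
print(sorted(wrapped.state_dict().keys()))  # ['module.bias', 'module.weight']
```

Saving wrapped.module.state_dict() instead would store the unprefixed keys; the tutorial does not do this, so it relies on wrapping consistently.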

Communicating with Ray Tune

The most interesting part is the communication with Ray Tune:

checkpoint_data = {
    "epoch": epoch,
    "net_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
with tempfile.TemporaryDirectory() as checkpoint_dir:
    data_path = Path(checkpoint_dir) / "data.pkl"
    with open(data_path, "wb") as fp:
        pickle.dump(checkpoint_data, fp)

    checkpoint = Checkpoint.from_directory(checkpoint_dir)
    train.report(
        {"loss": val_loss / val_steps, "accuracy": correct / total},
        checkpoint=checkpoint,
    )

Here we first write a checkpoint to a temporary directory and then report it, together with some metrics, back to Ray Tune. Specifically, we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics to decide which hyperparameter configuration leads to the best results. These metrics can also be used to stop poorly performing trials early, to avoid wasting resources on them.

Checkpoint saving is optional; however, it is necessary if we want to use advanced schedulers like Population Based Training. Also, by saving checkpoints we can later load the trained models and validate them on a test set. Lastly, saving checkpoints is useful for fault tolerance, and it allows us to interrupt training and continue it later.
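Because the checkpoint is just a pickled dictionary inside a directory, the save and restore round trip can be tried without Ray. The sketch below makes some substitutions: it uses the standard library’s pickle instead of ray.cloudpickle, a single nn.Linear layer instead of Net, and it skips Checkpoint.from_directory, train.report, and get_checkpoint entirely.

```python
import pickle
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(4, 2)
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Same dictionary layout as in train_cifar.
checkpoint_data = {
    "epoch": 3,
    "net_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}

with tempfile.TemporaryDirectory() as checkpoint_dir:
    data_path = Path(checkpoint_dir) / "data.pkl"
    # Save the state to data.pkl, as train_cifar does before reporting.
    with open(data_path, "wb") as fp:
        pickle.dump(checkpoint_data, fp)

    # Restore into freshly created objects, as train_cifar does on resume.
    new_net = nn.Linear(4, 2)
    new_optimizer = optim.SGD(new_net.parameters(), lr=0.01, momentum=0.9)
    with open(data_path, "rb") as fp:
        state = pickle.load(fp)
    new_net.load_state_dict(state["net_state_dict"])
    new_optimizer.load_state_dict(state["optimizer_state_dict"])
    start_epoch = state["epoch"]

print(start_epoch)                              # 3
print(torch.equal(new_net.weight, net.weight))  # True
```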

Full training function

The full code example looks like this:

def train_cifar(config, data_dir=None):
    net = Net(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    checkpoint = get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            with open(data_path, "rb") as fp:
                checkpoint_state = pickle.load(fp)
            start_epoch = checkpoint_state["epoch"]
            net.load_state_dict(checkpoint_state["net_state_dict"])
            optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
    else:
        start_epoch = 0

    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs]
    )

    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
    )
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
    )

    for epoch in range(start_epoch, 10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(
                    "[%d, %5d] loss: %.3f"
                    % (epoch + 1, i + 1, running_loss / epoch_steps)
                )
                running_loss = 0.0
                epoch_steps = 0  # reset too, so each print averages only the last 2000 mini-batches

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        checkpoint_data = {
            "epoch": epoch,
            "net_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        with tempfile.TemporaryDirectory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            with open(data_path, "wb") as fp:
                pickle.dump(checkpoint_data, fp)

            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            train.report(
                {"loss": val_loss / val_steps, "accuracy": correct / total},
                checkpoint=checkpoint,
            )

    print("Finished Training")

As you can see, most of the code is adapted directly from the original example.

Test set accuracy

Commonly, the performance of a machine learning model is tested on a hold-out test set with data that has not been used for training the model. We also wrap this in a function:

def test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

The function also expects a device parameter, so we can do the test set validation on a GPU.
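The accuracy computation in test_accuracy (and in the validation loop of train_cifar) takes the index of the largest output in each row as the predicted class. A small worked example with made-up logits for four images and three classes:

```python
import torch

# Made-up network outputs (logits): 4 images, 3 classes.
outputs = torch.tensor([
    [2.0, 0.1, -1.0],  # largest value at index 0
    [0.3, 1.5, 0.2],   # largest value at index 1
    [0.0, 0.2, 3.1],   # largest value at index 2
    [1.2, 0.9, 0.4],   # largest value at index 0
])
labels = torch.tensor([0, 1, 1, 0])

# torch.max along dim 1 returns (values, indices); the indices are the predicted classes.
_, predicted = torch.max(outputs, 1)
correct = (predicted == labels).sum().item()
total = labels.size(0)

print(predicted.tolist())  # [0, 1, 2, 0]
print(correct / total)     # 0.75
```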

Configuring the search space

Lastly, we need to define Ray Tune’s search space. Here is an example:

config = {
    "l1": tune.choice([2 ** i for i in range(9)]),
    "l2": tune.choice([2 ** i for i in range(9)]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16])
}

tune.choice() accepts a list of values from which one is sampled uniformly. In this example, the l1 and l2 parameters are powers of 2 between 1 and 256, so 1, 2, 4, 8, 16, 32, 64, 128, or 256. The lr (learning rate) is sampled log-uniformly between 0.0001 and 0.1, so each order of magnitude in that range is equally likely. Lastly, the batch size is a choice between 2, 4, 8, and 16.
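To make the search space concrete, the sketch below lists the candidate layer sizes and draws learning rates uniformly in log space. The sample_loguniform helper is hypothetical: it illustrates the idea behind tune.loguniform and is not Ray’s implementation.

```python
import math
import random

# The candidate layer sizes produced by the list comprehension in the config.
layer_sizes = [2 ** i for i in range(9)]
print(layer_sizes)  # [1, 2, 4, 8, 16, 32, 64, 128, 256]

def sample_loguniform(low, high, rng):
    """Hypothetical helper: sample uniformly between log(low) and log(high), then exponentiate."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

rng = random.Random(0)
samples = [sample_loguniform(1e-4, 1e-1, rng) for _ in range(10_000)]

# Each decade of the range should receive roughly a third of the samples.
for lo, hi in [(1e-4, 1e-3), (1e-3, 1e-2), (1e-2, 1e-1)]:
    share = sum(lo <= s < hi for s in samples) / len(samples)
    print(f"[{lo:g}, {hi:g}): {share:.2f}")
```

A plain uniform draw over the same range would put about 90% of the samples between 0.01 and 0.1, which is why log-uniform sampling is the usual choice for learning rates.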

For each trial, Ray Tune randomly samples a combination of parameters from these search spaces. It then trains a number of models in parallel and finds the best performing one among them. We also use the ASHAScheduler, which terminates poorly performing trials early.
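The main function later in this tutorial configures the ASHAScheduler with grace_period=1, reduction_factor=2, and max_t=max_num_epochs. The sketch below is a simplified model of asynchronous successive halving, not Ray’s implementation (which differs in details): trials are compared at milestone epochs that grow geometrically, and at each milestone a trial continues only if its loss ranks among the best 1/reduction_factor of the results recorded there so far. Both helper names are hypothetical.

```python
import math

def asha_milestones(grace_period, max_t, reduction_factor):
    """Hypothetical helper: grace_period * reduction_factor**k for every value below max_t."""
    milestones = []
    t = grace_period
    while t < max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones

def should_stop(value, rung_values, reduction_factor):
    """Hypothetical, simplified rule for a metric to minimize (like "loss"): keep the
    trial only if `value` is among the best 1/reduction_factor of all values
    recorded at this milestone, itself included."""
    values = sorted(rung_values + [value])
    num_kept = math.ceil(len(values) / reduction_factor)
    cutoff = values[num_kept - 1]  # worst value that still survives
    return value > cutoff

print(asha_milestones(grace_period=1, max_t=10, reduction_factor=2))  # [1, 2, 4, 8]
# A trial reaching a milestone where three earlier trials reported these losses:
print(should_stop(2.0, [1.9, 2.1, 2.3], reduction_factor=2))  # False: in the best half
print(should_stop(2.2, [1.9, 2.1, 2.3], reduction_factor=2))  # True: in the worst half
```

Because decisions are made as soon as a trial reaches a milestone, no trial has to wait for the others, which is what makes the scheme asynchronous.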

We wrap the train_cifar function with functools.partial to set the constant data_dir parameter. We can also tell Ray Tune what resources should be available for each trial:

gpus_per_trial = 2
# ...
# Checkpoints are reported from inside train_cifar via train.report(),
# so no extra checkpointing option is needed here.
result = tune.run(
    partial(train_cifar, data_dir=data_dir),
    resources_per_trial={"cpu": 8, "gpu": gpus_per_trial},
    config=config,
    num_samples=num_samples,
    scheduler=scheduler,
)

You can specify the number of CPUs, which are then available, for example, to increase the num_workers of the PyTorch DataLoader instances. The selected number of GPUs is made visible to PyTorch in each trial. Trials do not have access to GPUs that haven’t been requested for them, so you don’t have to worry about two trials using the same set of resources.

Here we can also specify fractional GPUs, so something like gpus_per_trial=0.5 is completely valid. The trials will then share GPUs among each other. You just have to make sure that the models still fit in the GPU memory.
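As a rough rule of thumb, the number of trials that can run at the same time is limited by whichever resource runs out first. The helper below is hypothetical and only counts CPUs and GPUs (it ignores memory and any custom resources). For the machine in the example run further below (16 CPUs, 1 GPU) with 2 CPUs and no GPU per trial, as in main, it gives 8, which matches the "8 RUNNING | 2 PENDING" status in that run’s log.

```python
import math

def max_concurrent_trials(total_cpus, total_gpus, cpus_per_trial, gpus_per_trial):
    """Hypothetical helper: upper bound on simultaneously running trials (CPUs and GPUs only)."""
    limits = [math.floor(total_cpus / cpus_per_trial)]
    if gpus_per_trial > 0:
        limits.append(math.floor(total_gpus / gpus_per_trial))
    return min(limits)

print(max_concurrent_trials(16, 1, 2, 0))    # 8: CPUs are the bottleneck
print(max_concurrent_trials(16, 1, 2, 0.5))  # 2: two trials share the single GPU
print(max_concurrent_trials(16, 1, 8, 0))    # 2: 8 CPUs per trial, as in the snippet above
```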

After training the models, we find the best performing one and load the trained network from its checkpoint file. We then obtain the test set accuracy and print the results.

The full main function looks like this:

def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    data_dir = os.path.abspath("./data")
    load_data(data_dir)
    config = {
        "l1": tune.choice([2**i for i in range(9)]),
        "l2": tune.choice([2**i for i in range(9)]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )
    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
    )

    best_trial = result.get_best_trial("loss", "min", "last")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    best_checkpoint = result.get_best_checkpoint(trial=best_trial, metric="accuracy", mode="max")
    with best_checkpoint.as_directory() as checkpoint_dir:
        data_path = Path(checkpoint_dir) / "data.pkl"
        with open(data_path, "rb") as fp:
            best_checkpoint_data = pickle.load(fp)

        best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
        test_acc = test_accuracy(best_trained_model, device)
        print("Best trial test set accuracy: {}".format(test_acc))


if __name__ == "__main__":
    # You can change the number of GPUs per trial here:
    main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)
2025-06-17 14:24:37,917 WARNING services.py:1889 -- WARNING: The object store is using /tmp instead of /dev/shm because /dev/shm has only 2147467264 bytes available. This will harm performance! You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you can increase /dev/shm size by passing '--shm-size=10.24gb' to 'docker run' (or add it to the run_options list in a Ray cluster config). Make sure to set this to more than 30% of available RAM.
2025-06-17 14:24:37,971 INFO worker.py:1642 -- Started a local Ray instance.
2025-06-17 14:24:38,908 INFO tune.py:228 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `tune.run(...)`.
2025-06-17 14:24:38,910 INFO tune.py:654 -- [output] This will use the new output engine with verbosity 2. To disable the new output and use the legacy output engine, set the environment variable RAY_AIR_NEW_OUTPUT=0. For more information, please see https://github.com/ray-project/ray/issues/36949
+--------------------------------------------------------------------+
| Configuration for experiment     train_cifar_2025-06-17_14-24-38   |
+--------------------------------------------------------------------+
| Search algorithm                 BasicVariantGenerator             |
| Scheduler                        AsyncHyperBandScheduler           |
| Number of trials                 10                                |
+--------------------------------------------------------------------+

View detailed results here: /var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38
To visualize your results with TensorBoard, run: `tensorboard --logdir /var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38`

Trial status: 10 PENDING
Current time: 2025-06-17 14:24:39. Total running time: 0s
Logical resource usage: 10.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+-------------------------------------------------------------------------------+
| Trial name                status       l1     l2            lr     batch_size |
+-------------------------------------------------------------------------------+
| train_cifar_cd391_00000   PENDING     256      2   0.00105263               2 |
| train_cifar_cd391_00001   PENDING      64      2   0.0189753               16 |
| train_cifar_cd391_00002   PENDING      16    256   0.0450584                2 |
| train_cifar_cd391_00003   PENDING       8     16   0.00920872               2 |
| train_cifar_cd391_00004   PENDING      64     16   0.000310926              2 |
| train_cifar_cd391_00005   PENDING       4      1   0.00322626               4 |
| train_cifar_cd391_00006   PENDING       1     16   0.000669639              4 |
| train_cifar_cd391_00007   PENDING      64      1   0.00143856               2 |
| train_cifar_cd391_00008   PENDING      32     64   0.00411186               8 |
| train_cifar_cd391_00009   PENDING     256     64   0.000399319              8 |
+-------------------------------------------------------------------------------+

Trial train_cifar_cd391_00004 started with configuration:
+--------------------------------------------------+
| Trial train_cifar_cd391_00004 config             |
+--------------------------------------------------+
| batch_size                                     2 |
| l1                                            64 |
| l2                                            16 |
| lr                                       0.00031 |
+--------------------------------------------------+

Trial train_cifar_cd391_00000 started with configuration:
+--------------------------------------------------+
| Trial train_cifar_cd391_00000 config             |
+--------------------------------------------------+
| batch_size                                     2 |
| l1                                           256 |
| l2                                             2 |
| lr                                       0.00105 |
+--------------------------------------------------+

Trial train_cifar_cd391_00003 started with configuration:
+--------------------------------------------------+
| Trial train_cifar_cd391_00003 config             |
+--------------------------------------------------+
| batch_size                                     2 |
| l1                                             8 |
| l2                                            16 |
| lr                                       0.00921 |
+--------------------------------------------------+

Trial train_cifar_cd391_00002 started with configuration:
+--------------------------------------------------+
| Trial train_cifar_cd391_00002 config             |
+--------------------------------------------------+
| batch_size                                     2 |
| l1                                            16 |
| l2                                           256 |
| lr                                       0.04506 |
+--------------------------------------------------+

Trial train_cifar_cd391_00006 started with configuration:
+--------------------------------------------------+
| Trial train_cifar_cd391_00006 config             |
+--------------------------------------------------+
| batch_size                                     4 |
| l1                                             1 |
| l2                                            16 |
| lr                                       0.00067 |
+--------------------------------------------------+

Trial train_cifar_cd391_00005 started with configuration:
+--------------------------------------------------+
| Trial train_cifar_cd391_00005 config             |
+--------------------------------------------------+
| batch_size                                     4 |
| l1                                             4 |
| l2                                             1 |
| lr                                       0.00323 |
+--------------------------------------------------+

Trial train_cifar_cd391_00001 started with configuration:
+--------------------------------------------------+
| Trial train_cifar_cd391_00001 config             |
+--------------------------------------------------+
| batch_size                                    16 |
| l1                                            64 |
| l2                                             2 |
| lr                                       0.01898 |
+--------------------------------------------------+

Trial train_cifar_cd391_00007 started with configuration:
+--------------------------------------------------+
| Trial train_cifar_cd391_00007 config             |
+--------------------------------------------------+
| batch_size                                     2 |
| l1                                            64 |
| l2                                             1 |
| lr                                       0.00144 |
+--------------------------------------------------+
(func pid=4879) [1,  2000] loss: 2.301
(func pid=4873) [1,  2000] loss: nan

Trial train_cifar_cd391_00001 finished iteration 1 at 2025-06-17 14:25:06. Total running time: 27s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                  22.96926 |
| time_total_s                                      22.96926 |
| training_iteration                                       1 |
| accuracy                                            0.1913 |
| loss                                               1.93487 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000000
(func pid=4866) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000000)
(func pid=4874) [1,  4000] loss: 1.107 [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(func pid=4873) [1,  4000] loss: nan

Trial status: 8 RUNNING | 2 PENDING
Current time: 2025-06-17 14:25:09. Total running time: 30s
Logical resource usage: 16.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+----------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status       l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy |
+----------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING     256      2   0.00105263               2                                                    |
| train_cifar_cd391_00001   RUNNING      64      2   0.0189753               16        1            22.9693   1.93487       0.1913 |
| train_cifar_cd391_00002   RUNNING      16    256   0.0450584                2                                                    |
| train_cifar_cd391_00003   RUNNING       8     16   0.00920872               2                                                    |
| train_cifar_cd391_00004   RUNNING      64     16   0.000310926              2                                                    |
| train_cifar_cd391_00005   RUNNING       4      1   0.00322626               4                                                    |
| train_cifar_cd391_00006   RUNNING       1     16   0.000669639              4                                                    |
| train_cifar_cd391_00007   RUNNING      64      1   0.00143856               2                                                    |
| train_cifar_cd391_00008   PENDING      32     64   0.00411186               8                                                    |
| train_cifar_cd391_00009   PENDING     256     64   0.000399319              8                                                    |
+----------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [1,  6000] loss: nan
(func pid=4880) [1,  4000] loss: 1.022 [repeated 5x across cluster]
(func pid=4866) [2,  2000] loss: 1.934 [repeated 7x across cluster]
(func pid=4873) [1,  8000] loss: nan
(func pid=4882) [1,  8000] loss: 0.576 [repeated 5x across cluster]

Trial train_cifar_cd391_00001 finished iteration 2 at 2025-06-17 14:25:30. Total running time: 51s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000001 |
| time_this_iter_s                                  23.81307 |
| time_total_s                                      46.78233 |
| training_iteration                                       2 |
| accuracy                                            0.1859 |
| loss                                               1.96024 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 2 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000001
(func pid=4866) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000001)
(func pid=4874) [1, 10000] loss: 0.462 [repeated 2x across cluster]
(func pid=4873) [1, 10000] loss: nan

Trial status: 8 RUNNING | 2 PENDING
Current time: 2025-06-17 14:25:39. Total running time: 1min 0s
Logical resource usage: 16.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+----------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status       l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy |
+----------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING     256      2   0.00105263               2                                                    |
| train_cifar_cd391_00001   RUNNING      64      2   0.0189753               16        2            46.7823   1.96024       0.1859 |
| train_cifar_cd391_00002   RUNNING      16    256   0.0450584                2                                                    |
| train_cifar_cd391_00003   RUNNING       8     16   0.00920872               2                                                    |
| train_cifar_cd391_00004   RUNNING      64     16   0.000310926              2                                                    |
| train_cifar_cd391_00005   RUNNING       4      1   0.00322626               4                                                    |
| train_cifar_cd391_00006   RUNNING       1     16   0.000669639              4                                                    |
| train_cifar_cd391_00007   RUNNING      64      1   0.00143856               2                                                    |
| train_cifar_cd391_00008   PENDING      32     64   0.00411186               8                                                    |
| train_cifar_cd391_00009   PENDING     256     64   0.000399319              8                                                    |
+----------------------------------------------------------------------------------------------------------------------------------+
(func pid=4866) [3,  2000] loss: 1.932 [repeated 6x across cluster]
(func pid=4873) [1, 12000] loss: nan

Trial train_cifar_cd391_00006 finished iteration 1 at 2025-06-17 14:25:50. Total running time: 1min 11s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00006 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                  67.02508 |
| time_total_s                                      67.02508 |
| training_iteration                                       1 |
| accuracy                                            0.2143 |
| loss                                               1.89611 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00006 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000000
(func pid=4881) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000000)

Trial train_cifar_cd391_00005 finished iteration 1 at 2025-06-17 14:25:51. Total running time: 1min 12s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00005 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                  67.52965 |
| time_total_s                                      67.52965 |
| training_iteration                                       1 |
| accuracy                                            0.2206 |
| loss                                               1.90983 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00005 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000000

Trial train_cifar_cd391_00001 finished iteration 3 at 2025-06-17 14:25:53. Total running time: 1min 14s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000002 |
| time_this_iter_s                                  22.44391 |
| time_total_s                                      69.22624 |
| training_iteration                                       3 |
| accuracy                                            0.1988 |
| loss                                               1.92248 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 3 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000002
(func pid=4874) [1, 14000] loss: 0.330 [repeated 5x across cluster]
(func pid=4873) [1, 14000] loss: nan
(func pid=4880) [2,  2000] loss: 1.918 [repeated 4x across cluster]
(func pid=4873) [1, 16000] loss: nan

Trial status: 8 RUNNING | 2 PENDING
Current time: 2025-06-17 14:26:09. Total running time: 1min 30s
Logical resource usage: 16.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+----------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status       l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy |
+----------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING     256      2   0.00105263               2                                                    |
| train_cifar_cd391_00001   RUNNING      64      2   0.0189753               16        3            69.2262   1.92248       0.1988 |
| train_cifar_cd391_00002   RUNNING      16    256   0.0450584                2                                                    |
| train_cifar_cd391_00003   RUNNING       8     16   0.00920872               2                                                    |
| train_cifar_cd391_00004   RUNNING      64     16   0.000310926              2                                                    |
| train_cifar_cd391_00005   RUNNING       4      1   0.00322626               4        1            67.5297   1.90983       0.2206 |
| train_cifar_cd391_00006   RUNNING       1     16   0.000669639              4        1            67.0251   1.89611       0.2143 |
| train_cifar_cd391_00007   RUNNING      64      1   0.00143856               2                                                    |
| train_cifar_cd391_00008   PENDING      32     64   0.00411186               8                                                    |
| train_cifar_cd391_00009   PENDING     256     64   0.000399319              8                                                    |
+----------------------------------------------------------------------------------------------------------------------------------+
(func pid=4882) [1, 16000] loss: 0.288 [repeated 5x across cluster]
(func pid=4881) [2,  4000] loss: 0.951 [repeated 2x across cluster]

Trial train_cifar_cd391_00001 finished iteration 4 at 2025-06-17 14:26:15. Total running time: 1min 36s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000003 |
| time_this_iter_s                                   22.3305 |
| time_total_s                                      91.55674 |
| training_iteration                                       4 |
| accuracy                                             0.254 |
| loss                                               1.92441 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 4 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000003
(func pid=4866) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000003) [repeated 3x across cluster]
(func pid=4873) [1, 18000] loss: nan
(func pid=4882) [1, 18000] loss: 0.256 [repeated 4x across cluster]
(func pid=4881) [2,  6000] loss: 0.630 [repeated 2x across cluster]
(func pid=4873) [1, 20000] loss: nan
(func pid=4882) [1, 20000] loss: 0.230 [repeated 4x across cluster]

Trial train_cifar_cd391_00001 finished iteration 5 at 2025-06-17 14:26:36. Total running time: 1min 57s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000004 |
| time_this_iter_s                                  21.08606 |
| time_total_s                                     112.64281 |
| training_iteration                                       5 |
| accuracy                                            0.2978 |
| loss                                               1.82417 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 5 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000004
(func pid=4866) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000004)

Trial status: 8 RUNNING | 2 PENDING
Current time: 2025-06-17 14:26:39. Total running time: 2min 0s
Logical resource usage: 16.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+----------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status       l1     l2            lr     batch_size     iter     total time (s)      loss     accuracy |
+----------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING     256      2   0.00105263               2                                                    |
| train_cifar_cd391_00001   RUNNING      64      2   0.0189753               16        5           112.643    1.82417       0.2978 |
| train_cifar_cd391_00002   RUNNING      16    256   0.0450584                2                                                    |
| train_cifar_cd391_00003   RUNNING       8     16   0.00920872               2                                                    |
| train_cifar_cd391_00004   RUNNING      64     16   0.000310926              2                                                    |
| train_cifar_cd391_00005   RUNNING       4      1   0.00322626               4        1            67.5297   1.90983       0.2206 |
| train_cifar_cd391_00006   RUNNING       1     16   0.000669639              4        1            67.0251   1.89611       0.2143 |
| train_cifar_cd391_00007   RUNNING      64      1   0.00143856               2                                                    |
| train_cifar_cd391_00008   PENDING      32     64   0.00411186               8                                                    |
| train_cifar_cd391_00009   PENDING     256     64   0.000399319              8                                                    |
+----------------------------------------------------------------------------------------------------------------------------------+
(func pid=4881) [2, 10000] loss: 0.375 [repeated 5x across cluster]

Trial train_cifar_cd391_00003 finished iteration 1 at 2025-06-17 14:26:47. Total running time: 2min 8s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00003 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                 124.22614 |
| time_total_s                                     124.22614 |
| training_iteration                                       1 |
| accuracy                                            0.0994 |
| loss                                                2.3138 |
+------------------------------------------------------------+
(func pid=4874) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00003_3_batch_size=2,l1=8,l2=16,lr=0.0092_2025-06-17_14-24-39/checkpoint_000000)
Trial train_cifar_cd391_00003 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00003_3_batch_size=2,l1=8,l2=16,lr=0.0092_2025-06-17_14-24-39/checkpoint_000000

Trial train_cifar_cd391_00003 completed after 1 iterations at 2025-06-17 14:26:47. Total running time: 2min 8s

Trial train_cifar_cd391_00008 started with configuration:
+--------------------------------------------------+
| Trial train_cifar_cd391_00008 config             |
+--------------------------------------------------+
| batch_size                                     8 |
| l1                                            32 |
| l2                                            64 |
| lr                                       0.00411 |
+--------------------------------------------------+

Trial train_cifar_cd391_00002 finished iteration 1 at 2025-06-17 14:26:48. Total running time: 2min 9s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                 124.45894 |
| time_total_s                                     124.45894 |
| training_iteration                                       1 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000000

Trial train_cifar_cd391_00004 finished iteration 1 at 2025-06-17 14:26:49. Total running time: 2min 10s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                 126.16599 |
| time_total_s                                     126.16599 |
| training_iteration                                       1 |
| accuracy                                            0.4351 |
| loss                                               1.54767 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000000

Trial train_cifar_cd391_00007 finished iteration 1 at 2025-06-17 14:26:49. Total running time: 2min 11s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00007 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                 126.11598 |
| time_total_s                                     126.11598 |
| training_iteration                                       1 |
| accuracy                                            0.0995 |
| loss                                               2.30463 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00007 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00007_7_batch_size=2,l1=64,l2=1,lr=0.0014_2025-06-17_14-24-39/checkpoint_000000

Trial train_cifar_cd391_00007 completed after 1 iterations at 2025-06-17 14:26:50. Total running time: 2min 11s

Trial train_cifar_cd391_00009 started with configuration:
+-------------------------------------------------+
| Trial train_cifar_cd391_00009 config            |
+-------------------------------------------------+
| batch_size                                    8 |
| l1                                          256 |
| l2                                           64 |
| lr                                       0.0004 |
+-------------------------------------------------+
(func pid=4866) [6,  2000] loss: 1.829 [repeated 2x across cluster]

Trial train_cifar_cd391_00000 finished iteration 1 at 2025-06-17 14:26:51. Total running time: 2min 12s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                 128.06979 |
| time_total_s                                     128.06979 |
| training_iteration                                       1 |
| accuracy                                            0.2913 |
| loss                                               1.84174 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00000 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000000

Trial train_cifar_cd391_00005 finished iteration 2 at 2025-06-17 14:26:55. Total running time: 2min 16s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00005 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000001 |
| time_this_iter_s                                  64.44225 |
| time_total_s                                      131.9719 |
| training_iteration                                       2 |
| accuracy                                            0.2019 |
| loss                                               1.91756 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00005 saved a checkpoint for iteration 2 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000001
(func pid=4880) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000001) [repeated 5x across cluster]

Trial train_cifar_cd391_00006 finished iteration 2 at 2025-06-17 14:26:56. Total running time: 2min 17s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00006 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000001 |
| time_this_iter_s                                  65.59584 |
| time_total_s                                     132.62092 |
| training_iteration                                       2 |
| accuracy                                            0.2162 |
| loss                                               1.85498 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00006 saved a checkpoint for iteration 2 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000001

Trial train_cifar_cd391_00001 finished iteration 6 at 2025-06-17 14:26:56. Total running time: 2min 17s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000005 |
| time_this_iter_s                                    20.222 |
| time_total_s                                      132.8648 |
| training_iteration                                       6 |
| accuracy                                            0.3033 |
| loss                                               1.90369 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 6 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000005
(func pid=4873) [2,  2000] loss: nan
(func pid=4874) [1,  2000] loss: 1.936
(func pid=4880) [3,  2000] loss: 1.900 [repeated 4x across cluster]

Trial status: 8 RUNNING | 2 TERMINATED
Current time: 2025-06-17 14:27:09. Total running time: 2min 30s
Logical resource usage: 16.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        1            128.07      1.84174       0.2913 |
| train_cifar_cd391_00001   RUNNING        64      2   0.0189753               16        6            132.865     1.90369       0.3033 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        1            124.459   nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        1            126.166     1.54767       0.4351 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        2            131.972     1.91756       0.2019 |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        2            132.621     1.85498       0.2162 |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8                                                      |
| train_cifar_cd391_00009   RUNNING       256     64   0.000399319              8                                                      |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1            124.226     2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1            126.116     2.30463       0.0995 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [2,  4000] loss: nan
(func pid=4874) [1,  4000] loss: 0.807 [repeated 4x across cluster]
(func pid=4880) [3,  4000] loss: 0.948 [repeated 3x across cluster]

Trial train_cifar_cd391_00001 finished iteration 7 at 2025-06-17 14:27:20. Total running time: 2min 41s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000006 |
| time_this_iter_s                                  23.51324 |
| time_total_s                                     156.37804 |
| training_iteration                                       7 |
| accuracy                                            0.3219 |
| loss                                                 1.773 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 7 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000006
(func pid=4866) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000006) [repeated 3x across cluster]
(func pid=4873) [2,  6000] loss: nan
(func pid=4863) [2,  6000] loss: 0.572 [repeated 3x across cluster]

Trial train_cifar_cd391_00008 finished iteration 1 at 2025-06-17 14:27:28. Total running time: 2min 49s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                  40.55282 |
| time_total_s                                      40.55282 |
| training_iteration                                       1 |
| accuracy                                            0.4476 |
| loss                                               1.51272 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000000
(func pid=4874) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000000)
(func pid=4873) [2,  8000] loss: nan

Trial train_cifar_cd391_00009 finished iteration 1 at 2025-06-17 14:27:33. Total running time: 2min 54s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00009 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000000 |
| time_this_iter_s                                  43.20764 |
| time_total_s                                      43.20764 |
| training_iteration                                       1 |
| accuracy                                            0.2744 |
| loss                                               1.95413 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00009 saved a checkpoint for iteration 1 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00009_9_batch_size=8,l1=256,l2=64,lr=0.0004_2025-06-17_14-24-39/checkpoint_000000

Trial train_cifar_cd391_00009 completed after 1 iterations at 2025-06-17 14:27:33. Total running time: 2min 54s
(func pid=4882) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00009_9_batch_size=8,l1=256,l2=64,lr=0.0004_2025-06-17_14-24-39/checkpoint_000000)
(func pid=4879) [2,  8000] loss: 0.373 [repeated 3x across cluster]

Trial status: 7 RUNNING | 3 TERMINATED
Current time: 2025-06-17 14:27:39. Total running time: 3min 0s
Logical resource usage: 14.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        1           128.07       1.84174       0.2913 |
| train_cifar_cd391_00001   RUNNING        64      2   0.0189753               16        7           156.378      1.773         0.3219 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        1           124.459    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        1           126.166      1.54767       0.4351 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        2           131.972      1.91756       0.2019 |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        2           132.621      1.85498       0.2162 |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8        1            40.5528     1.51272       0.4476 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4880) [3,  8000] loss: 0.474 [repeated 3x across cluster]
(func pid=4873) [2, 10000] loss: nan

Trial train_cifar_cd391_00001 finished iteration 8 at 2025-06-17 14:27:41. Total running time: 3min 2s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000007 |
| time_this_iter_s                                  20.66375 |
| time_total_s                                     177.04179 |
| training_iteration                                       8 |
| accuracy                                            0.2983 |
| loss                                               1.81638 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 8 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000007
(func pid=4866) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000007)
(func pid=4863) [2, 10000] loss: 0.336 [repeated 4x across cluster]
(func pid=4873) [2, 12000] loss: nan
(func pid=4874) [2,  4000] loss: 0.710 [repeated 3x across cluster]
(func pid=4873) [2, 14000] loss: nan
(func pid=4863) [2, 12000] loss: 0.278 [repeated 3x across cluster]

Trial train_cifar_cd391_00005 finished iteration 3 at 2025-06-17 14:28:00. Total running time: 3min 21s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00005 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000002 |
| time_this_iter_s                                   64.4185 |
| time_total_s                                      196.3904 |
| training_iteration                                       3 |
| accuracy                                             0.232 |
| loss                                               1.85493 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00005 saved a checkpoint for iteration 3 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000002

Trial train_cifar_cd391_00001 finished iteration 9 at 2025-06-17 14:28:00. Total running time: 3min 21s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000008 |
| time_this_iter_s                                   19.2195 |
| time_total_s                                     196.26129 |
| training_iteration                                       9 |
| accuracy                                            0.2735 |
| loss                                               1.91524 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 9 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000008
(func pid=4866) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000008)

Trial train_cifar_cd391_00006 finished iteration 3 at 2025-06-17 14:28:01. Total running time: 3min 22s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00006 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000002 |
| time_this_iter_s                                  64.60942 |
| time_total_s                                     197.23034 |
| training_iteration                                       3 |
| accuracy                                            0.2399 |
| loss                                               1.81956 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00006 saved a checkpoint for iteration 3 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000002

Trial train_cifar_cd391_00008 finished iteration 2 at 2025-06-17 14:28:03. Total running time: 3min 24s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000001 |
| time_this_iter_s                                  35.25613 |
| time_total_s                                      75.80894 |
| training_iteration                                       2 |
| accuracy                                            0.5241 |
| loss                                               1.34218 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 2 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000001
(func pid=4863) [2, 14000] loss: 0.239 [repeated 2x across cluster]
(func pid=4873) [2, 16000] loss: nan

Trial status: 7 RUNNING | 3 TERMINATED
Current time: 2025-06-17 14:28:09. Total running time: 3min 30s
Logical resource usage: 14.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        1           128.07       1.84174       0.2913 |
| train_cifar_cd391_00001   RUNNING        64      2   0.0189753               16        9           196.261      1.91524       0.2735 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        1           124.459    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        1           126.166      1.54767       0.4351 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        3           196.39       1.85493       0.232  |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        3           197.23       1.81956       0.2399 |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8        2            75.8089     1.34218       0.5241 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4879) [2, 16000] loss: 0.179
(func pid=4880) [4,  2000] loss: 1.880
(func pid=4873) [2, 18000] loss: nan
(func pid=4874) [3,  2000] loss: 1.361 [repeated 4x across cluster]

Trial train_cifar_cd391_00001 finished iteration 10 at 2025-06-17 14:28:19. Total running time: 3min 40s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00001 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000009 |
| time_this_iter_s                                  19.05792 |
| time_total_s                                     215.31921 |
| training_iteration                                      10 |
| accuracy                                            0.2731 |
| loss                                               1.87434 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00001 saved a checkpoint for iteration 10 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000009

Trial train_cifar_cd391_00001 completed after 10 iterations at 2025-06-17 14:28:19. Total running time: 3min 40s
(func pid=4866) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00001_1_batch_size=16,l1=64,l2=2,lr=0.0190_2025-06-17_14-24-39/checkpoint_000009) [repeated 4x across cluster]
(func pid=4863) [2, 18000] loss: 0.188 [repeated 4x across cluster]
(func pid=4873) [2, 20000] loss: nan
(func pid=4880) [4,  6000] loss: 0.632 [repeated 3x across cluster]

Trial train_cifar_cd391_00008 finished iteration 3 at 2025-06-17 14:28:33. Total running time: 3min 54s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000002 |
| time_this_iter_s                                  30.01278 |
| time_total_s                                     105.82172 |
| training_iteration                                       3 |
| accuracy                                            0.5197 |
| loss                                               1.37711 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 3 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000002
(func pid=4874) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000002)
(func pid=4880) [4,  8000] loss: 0.476 [repeated 3x across cluster]

Trial status: 6 RUNNING | 4 TERMINATED
Current time: 2025-06-17 14:28:39. Total running time: 4min 0s
Logical resource usage: 12.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        1           128.07       1.84174       0.2913 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        1           124.459    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        1           126.166      1.54767       0.4351 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        3           196.39       1.85493       0.232  |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        3           197.23       1.81956       0.2399 |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8        3           105.822      1.37711       0.5197 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000001)

Trial train_cifar_cd391_00002 finished iteration 2 at 2025-06-17 14:28:40. Total running time: 4min 1s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000001 |
| time_this_iter_s                                 112.03367 |
| time_total_s                                     236.49261 |
| training_iteration                                       2 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 2 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000001

Trial train_cifar_cd391_00004 finished iteration 2 at 2025-06-17 14:28:42. Total running time: 4min 3s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000001 |
| time_this_iter_s                                 113.09564 |
| time_total_s                                     239.26163 |
| training_iteration                                       2 |
| accuracy                                            0.4837 |
| loss                                               1.41163 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 2 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000001
(func pid=4874) [4,  2000] loss: 1.285 [repeated 2x across cluster]

Trial train_cifar_cd391_00000 finished iteration 2 at 2025-06-17 14:28:45. Total running time: 4min 6s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000001 |
| time_this_iter_s                                 113.85691 |
| time_total_s                                      241.9267 |
| training_iteration                                       2 |
| accuracy                                            0.3553 |
| loss                                                1.6424 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00000 saved a checkpoint for iteration 2 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000001
(func pid=4863) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000001) [repeated 2x across cluster]
(func pid=4873) [3,  2000] loss: nan
(func pid=4881) [4, 10000] loss: 0.368 [repeated 2x across cluster]

Trial train_cifar_cd391_00005 finished iteration 4 at 2025-06-17 14:28:53. Total running time: 4min 14s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00005 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000003 |
| time_this_iter_s                                  52.87169 |
| time_total_s                                     249.26209 |
| training_iteration                                       4 |
| accuracy                                             0.218 |
| loss                                               1.85685 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00005 saved a checkpoint for iteration 4 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000003
(func pid=4880) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000003)

Trial train_cifar_cd391_00006 finished iteration 4 at 2025-06-17 14:28:53. Total running time: 4min 14s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00006 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000003 |
| time_this_iter_s                                  52.33475 |
| time_total_s                                     249.56509 |
| training_iteration                                       4 |
| accuracy                                            0.2388 |
| loss                                               1.82043 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00006 saved a checkpoint for iteration 4 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000003
(func pid=4881) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000003)
(func pid=4863) [3,  2000] loss: 1.611 [repeated 3x across cluster]
(func pid=4873) [3,  4000] loss: nan
(func pid=4879) [3,  4000] loss: 0.686

Trial train_cifar_cd391_00008 finished iteration 4 at 2025-06-17 14:29:01. Total running time: 4min 22s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000003 |
| time_this_iter_s                                  28.01891 |
| time_total_s                                     133.84063 |
| training_iteration                                       4 |
| accuracy                                            0.5436 |
| loss                                               1.29817 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 4 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000003
(func pid=4874) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000003)
(func pid=4880) [5,  2000] loss: 1.875
(func pid=4873) [3,  6000] loss: nan
(func pid=4879) [3,  6000] loss: 0.456 [repeated 3x across cluster]

Trial status: 6 RUNNING | 4 TERMINATED
Current time: 2025-06-17 14:29:09. Total running time: 4min 30s
Logical resource usage: 12.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        2           241.927      1.6424        0.3553 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        2           236.493    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        2           239.262      1.41163       0.4837 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        4           249.262      1.85685       0.218  |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        4           249.565      1.82043       0.2388 |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8        4           133.841      1.29817       0.5436 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [3,  8000] loss: nan
(func pid=4879) [3,  8000] loss: 0.340 [repeated 5x across cluster]
(func pid=4873) [3, 10000] loss: nan
(func pid=4874) [5,  4000] loss: 0.640 [repeated 4x across cluster]
(func pid=4881) [5,  8000] loss: 0.452 [repeated 2x across cluster]
(func pid=4873) [3, 12000] loss: nan

Trial train_cifar_cd391_00008 finished iteration 5 at 2025-06-17 14:29:29. Total running time: 4min 50s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000004 |
| time_this_iter_s                                  27.94115 |
| time_total_s                                     161.78177 |
| training_iteration                                       5 |
| accuracy                                            0.5228 |
| loss                                               1.34199 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 5 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000004
(func pid=4874) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000004)
(func pid=4879) [3, 12000] loss: 0.221 [repeated 3x across cluster]
(func pid=4873) [3, 14000] loss: nan

Trial status: 6 RUNNING | 4 TERMINATED
Current time: 2025-06-17 14:29:39. Total running time: 5min 0s
Logical resource usage: 12.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        2           241.927      1.6424        0.3553 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        2           236.493    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        2           239.262      1.41163       0.4837 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        4           249.262      1.85685       0.218  |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        4           249.565      1.82043       0.2388 |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8        5           161.782      1.34199       0.5228 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4879) [3, 14000] loss: 0.193 [repeated 4x across cluster]

Trial train_cifar_cd391_00006 finished iteration 5 at 2025-06-17 14:29:42. Total running time: 5min 4s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00006 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000004 |
| time_this_iter_s                                  49.62623 |
| time_total_s                                     299.19132 |
| training_iteration                                       5 |
| accuracy                                            0.2502 |
| loss                                                1.8066 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00006 saved a checkpoint for iteration 5 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000004
(func pid=4881) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000004)

Trial train_cifar_cd391_00005 finished iteration 5 at 2025-06-17 14:29:43. Total running time: 5min 4s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00005 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000004 |
| time_this_iter_s                                  50.49998 |
| time_total_s                                     299.76207 |
| training_iteration                                       5 |
| accuracy                                            0.2312 |
| loss                                               1.86742 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00005 saved a checkpoint for iteration 5 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000004
(func pid=4880) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000004)
(func pid=4873) [3, 16000] loss: nan
(func pid=4879) [3, 16000] loss: 0.162 [repeated 3x across cluster]
(func pid=4873) [3, 18000] loss: nan
(func pid=4879) [3, 18000] loss: 0.146 [repeated 5x across cluster]

Trial train_cifar_cd391_00008 finished iteration 6 at 2025-06-17 14:29:57. Total running time: 5min 18s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000005 |
| time_this_iter_s                                  28.23322 |
| time_total_s                                     190.01499 |
| training_iteration                                       6 |
| accuracy                                            0.5242 |
| loss                                               1.38188 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 6 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000005
(func pid=4874) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000005)
(func pid=4873) [3, 20000] loss: nan
(func pid=4863) [3, 18000] loss: 0.176 [repeated 2x across cluster]
(func pid=4874) [7,  2000] loss: 1.215 [repeated 3x across cluster]

Trial status: 6 RUNNING | 4 TERMINATED
Current time: 2025-06-17 14:30:09. Total running time: 5min 30s
Logical resource usage: 12.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        2           241.927      1.6424        0.3553 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        2           236.493    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        2           239.262      1.41163       0.4837 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        5           299.762      1.86742       0.2312 |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        5           299.191      1.8066        0.2502 |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8        6           190.015      1.38188       0.5242 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+

Trial train_cifar_cd391_00002 finished iteration 3 at 2025-06-17 14:30:15. Total running time: 5min 36s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000002 |
| time_this_iter_s                                  94.95772 |
| time_total_s                                     331.45033 |
| training_iteration                                       3 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 3 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000002
(func pid=4873) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000002)
(func pid=4881) [6,  8000] loss: 0.448 [repeated 4x across cluster]

Trial train_cifar_cd391_00004 finished iteration 3 at 2025-06-17 14:30:19. Total running time: 5min 40s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000002 |
| time_this_iter_s                                  96.63778 |
| time_total_s                                     335.89941 |
| training_iteration                                       3 |
| accuracy                                            0.5044 |
| loss                                                1.3599 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 3 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000002
(func pid=4873) [4,  2000] loss: nan
(func pid=4880) [6,  8000] loss: 0.465 [repeated 2x across cluster]

Trial train_cifar_cd391_00000 finished iteration 3 at 2025-06-17 14:30:24. Total running time: 5min 45s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000002 |
| time_this_iter_s                                  99.34851 |
| time_total_s                                     341.27521 |
| training_iteration                                       3 |
| accuracy                                            0.3663 |
| loss                                                1.5871 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00000 saved a checkpoint for iteration 3 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000002
(func pid=4863) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000002) [repeated 2x across cluster]

Trial train_cifar_cd391_00008 finished iteration 7 at 2025-06-17 14:30:26. Total running time: 5min 47s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000006 |
| time_this_iter_s                                  28.06529 |
| time_total_s                                     218.08029 |
| training_iteration                                       7 |
| accuracy                                            0.5407 |
| loss                                               1.31831 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 7 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000006
(func pid=4873) [4,  4000] loss: nan
(func pid=4879) [4,  2000] loss: 1.265 [repeated 3x across cluster]

Trial train_cifar_cd391_00006 finished iteration 6 at 2025-06-17 14:30:32. Total running time: 5min 53s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00006 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000005 |
| time_this_iter_s                                  49.85891 |
| time_total_s                                     349.05022 |
| training_iteration                                       6 |
| accuracy                                             0.264 |
| loss                                               1.79823 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00006 saved a checkpoint for iteration 6 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000005
(func pid=4881) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000005) [repeated 2x across cluster]

Trial train_cifar_cd391_00005 finished iteration 6 at 2025-06-17 14:30:33. Total running time: 5min 54s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00005 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000005 |
| time_this_iter_s                                  50.36257 |
| time_total_s                                     350.12464 |
| training_iteration                                       6 |
| accuracy                                             0.243 |
| loss                                               1.85766 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00005 saved a checkpoint for iteration 6 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000005

Trial status: 6 RUNNING | 4 TERMINATED
Current time: 2025-06-17 14:30:39. Total running time: 6min 0s
Logical resource usage: 12.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        3           341.275      1.5871        0.3663 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        3           331.45     nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        3           335.899      1.3599        0.5044 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        6           350.125      1.85766       0.243  |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        6           349.05       1.79823       0.264  |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8        7           218.08       1.31831       0.5407 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [4,  6000] loss: nan
(func pid=4879) [4,  4000] loss: 0.637 [repeated 3x across cluster]
(func pid=4874) [8,  4000] loss: 0.610 [repeated 5x across cluster]
(func pid=4873) [4,  8000] loss: nan
(func pid=4881) [7,  4000] loss: 0.886 [repeated 2x across cluster]

Trial train_cifar_cd391_00008 finished iteration 8 at 2025-06-17 14:30:54. Total running time: 6min 15s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000007 |
| time_this_iter_s                                  28.16354 |
| time_total_s                                     246.24383 |
| training_iteration                                       8 |
| accuracy                                            0.5485 |
| loss                                               1.29488 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 8 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000007
(func pid=4874) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000007) [repeated 2x across cluster]
(func pid=4873) [4, 10000] loss: nan
(func pid=4879) [4,  8000] loss: 0.313 [repeated 2x across cluster]
(func pid=4873) [4, 12000] loss: nan
(func pid=4879) [4, 10000] loss: 0.249 [repeated 4x across cluster]
(func pid=4879) [4, 12000] loss: 0.208 [repeated 5x across cluster]

Trial status: 6 RUNNING | 4 TERMINATED
Current time: 2025-06-17 14:31:09. Total running time: 6min 30s
Logical resource usage: 12.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        3           341.275      1.5871        0.3663 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        3           331.45     nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        3           335.899      1.3599        0.5044 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        6           350.125      1.85766       0.243  |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        6           349.05       1.79823       0.264  |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8        8           246.244      1.29488       0.5485 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [4, 14000] loss: nan
(func pid=4874) [9,  4000] loss: 0.607
(func pid=4881) [7, 10000] loss: 0.354
(func pid=4873) [4, 16000] loss: nan

Trial train_cifar_cd391_00008 finished iteration 9 at 2025-06-17 14:31:22. Total running time: 6min 43s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000008 |
| time_this_iter_s                                  28.42496 |
| time_total_s                                      274.6688 |
| training_iteration                                       9 |
| accuracy                                            0.5652 |
| loss                                               1.25249 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 9 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000008
(func pid=4874) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000008)

Trial train_cifar_cd391_00006 finished iteration 7 at 2025-06-17 14:31:23. Total running time: 6min 44s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00006 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000006 |
| time_this_iter_s                                  50.82295 |
| time_total_s                                     399.87318 |
| training_iteration                                       7 |
| accuracy                                            0.2731 |
| loss                                               1.76998 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00006 saved a checkpoint for iteration 7 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000006
(func pid=4881) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000006)
(func pid=4863) [4, 14000] loss: 0.220 [repeated 4x across cluster]

Trial train_cifar_cd391_00005 finished iteration 7 at 2025-06-17 14:31:24. Total running time: 6min 45s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00005 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000006 |
| time_this_iter_s                                  50.63712 |
| time_total_s                                     400.76176 |
| training_iteration                                       7 |
| accuracy                                            0.2332 |
| loss                                               1.88245 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00005 saved a checkpoint for iteration 7 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000006
(func pid=4873) [4, 18000] loss: nan
(func pid=4863) [4, 16000] loss: 0.194 [repeated 2x across cluster]
(func pid=4873) [4, 20000] loss: nan

Trial status: 6 RUNNING | 4 TERMINATED
Current time: 2025-06-17 14:31:39. Total running time: 7min 1s
Logical resource usage: 12.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        3           341.275      1.5871        0.3663 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        3           331.45     nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        3           335.899      1.3599        0.5044 |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        7           400.762      1.88245       0.2332 |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        7           399.873      1.76998       0.2731 |
| train_cifar_cd391_00008   RUNNING        32     64   0.00411186               8        9           274.669      1.25249       0.5652 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4863) [4, 18000] loss: 0.174 [repeated 5x across cluster]
(func pid=4863) [4, 20000] loss: 0.155 [repeated 5x across cluster]

Trial train_cifar_cd391_00002 finished iteration 4 at 2025-06-17 14:31:50. Total running time: 7min 11s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000003 |
| time_this_iter_s                                  95.59439 |
| time_total_s                                     427.04472 |
| training_iteration                                       4 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 4 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000003
(func pid=4873) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000003) [repeated 2x across cluster]

Trial train_cifar_cd391_00008 finished iteration 10 at 2025-06-17 14:31:50. Total running time: 7min 11s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00008 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000009 |
| time_this_iter_s                                  28.20754 |
| time_total_s                                     302.87634 |
| training_iteration                                      10 |
| accuracy                                            0.5625 |
| loss                                               1.28149 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00008 saved a checkpoint for iteration 10 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00008_8_batch_size=8,l1=32,l2=64,lr=0.0041_2025-06-17_14-24-39/checkpoint_000009

Trial train_cifar_cd391_00008 completed after 10 iterations at 2025-06-17 14:31:50. Total running time: 7min 11s

Trial train_cifar_cd391_00004 finished iteration 4 at 2025-06-17 14:31:55. Total running time: 7min 16s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000003 |
| time_this_iter_s                                  96.20584 |
| time_total_s                                     432.10525 |
| training_iteration                                       4 |
| accuracy                                             0.558 |
| loss                                               1.25205 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 4 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000003
(func pid=4881) [8,  8000] loss: 0.442 [repeated 3x across cluster]
(func pid=4873) [5,  2000] loss: nan

Trial train_cifar_cd391_00000 finished iteration 4 at 2025-06-17 14:32:01. Total running time: 7min 22s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000003 |
| time_this_iter_s                                  96.81904 |
| time_total_s                                     438.09425 |
| training_iteration                                       4 |
| accuracy                                            0.4058 |
| loss                                               1.57534 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00000 saved a checkpoint for iteration 4 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000003
(func pid=4863) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000003) [repeated 3x across cluster]
(func pid=4879) [5,  2000] loss: 1.180 [repeated 2x across cluster]
(func pid=4873) [5,  4000] loss: nan
(func pid=4863) [5,  2000] loss: 1.489 [repeated 3x across cluster]

Trial status: 5 RUNNING | 5 TERMINATED
Current time: 2025-06-17 14:32:10. Total running time: 7min 31s
Logical resource usage: 10.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        4           438.094      1.57534       0.4058 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        4           427.045    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        4           432.105      1.25205       0.558  |
| train_cifar_cd391_00005   RUNNING         4      1   0.00322626               4        7           400.762      1.88245       0.2332 |
| train_cifar_cd391_00006   RUNNING         1     16   0.000669639              4        7           399.873      1.76998       0.2731 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+

Trial train_cifar_cd391_00006 finished iteration 8 at 2025-06-17 14:32:11. Total running time: 7min 32s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00006 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000007 |
| time_this_iter_s                                  47.78776 |
| time_total_s                                     447.66093 |
| training_iteration                                       8 |
| accuracy                                            0.2635 |
| loss                                                1.7753 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00006 saved a checkpoint for iteration 8 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000007

Trial train_cifar_cd391_00006 completed after 8 iterations at 2025-06-17 14:32:11. Total running time: 7min 32s
(func pid=4881) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00006_6_batch_size=4,l1=1,l2=16,lr=0.0007_2025-06-17_14-24-39/checkpoint_000007)

(func pid=4880) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000007)
Trial train_cifar_cd391_00005 finished iteration 8 at 2025-06-17 14:32:12. Total running time: 7min 33s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00005 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000007 |
| time_this_iter_s                                  47.86677 |
| time_total_s                                     448.62853 |
| training_iteration                                       8 |
| accuracy                                            0.2387 |
| loss                                               1.88082 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00005 saved a checkpoint for iteration 8 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00005_5_batch_size=4,l1=4,l2=1,lr=0.0032_2025-06-17_14-24-39/checkpoint_000007

Trial train_cifar_cd391_00005 completed after 8 iterations at 2025-06-17 14:32:12. Total running time: 7min 33s
(func pid=4873) [5,  6000] loss: nan
(func pid=4863) [5,  4000] loss: 0.769 [repeated 2x across cluster]
(func pid=4873) [5,  8000] loss: nan
(func pid=4879) [5,  8000] loss: 0.296 [repeated 2x across cluster]
(func pid=4873) [5, 10000] loss: nan
(func pid=4873) [5, 12000] loss: nan
(func pid=4863) [5,  6000] loss: 0.500
(func pid=4879) [5, 10000] loss: 0.238
(func pid=4873) [5, 14000] loss: nan
(func pid=4863) [5,  8000] loss: 0.382
(func pid=4863) [5, 10000] loss: 0.304

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:32:40. Total running time: 8min 1s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        4           438.094      1.57534       0.4058 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        4           427.045    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        4           432.105      1.25205       0.558  |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4879) [5, 14000] loss: 0.164 [repeated 2x across cluster]
(func pid=4873) [5, 16000] loss: nan
(func pid=4879) [5, 16000] loss: 0.146 [repeated 2x across cluster]
(func pid=4873) [5, 18000] loss: nan
(func pid=4873) [5, 20000] loss: nan
(func pid=4863) [5, 14000] loss: 0.219
(func pid=4879) [5, 18000] loss: 0.132
(func pid=4879) [5, 20000] loss: 0.119 [repeated 2x across cluster]

Trial train_cifar_cd391_00002 finished iteration 5 at 2025-06-17 14:33:02. Total running time: 8min 23s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000004 |
| time_this_iter_s                                  72.11243 |
| time_total_s                                     499.15715 |
| training_iteration                                       5 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 5 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000004
(func pid=4873) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000004)
(func pid=4863) [5, 20000] loss: 0.151 [repeated 2x across cluster]

Trial train_cifar_cd391_00004 finished iteration 5 at 2025-06-17 14:33:09. Total running time: 8min 30s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000004 |
| time_this_iter_s                                  73.33042 |
| time_total_s                                     505.43567 |
| training_iteration                                       5 |
| accuracy                                            0.5559 |
| loss                                                1.2496 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 5 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000004
(func pid=4879) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000004)
(func pid=4873) [6,  2000] loss: nan

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:33:10. Total running time: 8min 31s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        4           438.094      1.57534       0.4058 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        5           499.157    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        5           505.436      1.2496        0.5559 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [6,  4000] loss: nan
(func pid=4879) [6,  2000] loss: 1.107

Trial train_cifar_cd391_00000 finished iteration 5 at 2025-06-17 14:33:17. Total running time: 8min 38s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000004 |
| time_this_iter_s                                  76.09005 |
| time_total_s                                      514.1843 |
| training_iteration                                       5 |
| accuracy                                            0.3928 |
| loss                                                1.5534 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00000 saved a checkpoint for iteration 5 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000004
(func pid=4863) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000004)
(func pid=4873) [6,  6000] loss: nan
(func pid=4879) [6,  4000] loss: 0.559
(func pid=4873) [6,  8000] loss: nan
(func pid=4879) [6,  6000] loss: 0.370 [repeated 2x across cluster]
(func pid=4873) [6, 10000] loss: nan
(func pid=4879) [6,  8000] loss: 0.286 [repeated 2x across cluster]
(func pid=4873) [6, 12000] loss: nan
(func pid=4863) [6,  6000] loss: 0.487
(func pid=4879) [6, 10000] loss: 0.232

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:33:40. Total running time: 9min 1s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        5           514.184      1.5534        0.3928 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        5           499.157    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        5           505.436      1.2496        0.5559 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [6, 14000] loss: nan
(func pid=4863) [6,  8000] loss: 0.367
(func pid=4879) [6, 12000] loss: 0.183
(func pid=4879) [6, 14000] loss: 0.161 [repeated 2x across cluster]
(func pid=4873) [6, 16000] loss: nan
(func pid=4879) [6, 16000] loss: 0.145
(func pid=4863) [6, 12000] loss: 0.252
(func pid=4873) [6, 18000] loss: nan
(func pid=4879) [6, 18000] loss: 0.124
(func pid=4873) [6, 20000] loss: nan
(func pid=4863) [6, 14000] loss: 0.213
(func pid=4879) [6, 20000] loss: 0.114
(func pid=4863) [6, 16000] loss: 0.182
Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:34:10. Total running time: 9min 31s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        5           514.184      1.5534        0.3928 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        5           499.157    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        5           505.436      1.2496        0.5559 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+

Trial train_cifar_cd391_00002 finished iteration 6 at 2025-06-17 14:34:12. Total running time: 9min 33s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000005 |
| time_this_iter_s                                  69.40327 |
| time_total_s                                     568.56041 |
| training_iteration                                       6 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 6 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000005
(func pid=4873) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000005)
(func pid=4863) [6, 18000] loss: 0.166

Trial train_cifar_cd391_00004 finished iteration 6 at 2025-06-17 14:34:17. Total running time: 9min 38s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000005 |
| time_this_iter_s                                  68.50119 |
| time_total_s                                     573.93686 |
| training_iteration                                       6 |
| accuracy                                            0.5725 |
| loss                                               1.21882 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 6 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000005
(func pid=4879) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000005)
(func pid=4873) [7,  2000] loss: nan
(func pid=4863) [6, 20000] loss: 0.149
(func pid=4873) [7,  4000] loss: nan
(func pid=4879) [7,  4000] loss: 0.531 [repeated 2x across cluster]
(func pid=4873) [7,  6000] loss: nan

Trial train_cifar_cd391_00000 finished iteration 6 at 2025-06-17 14:34:30. Total running time: 9min 51s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000005 |
| time_this_iter_s                                  72.79653 |
| time_total_s                                     586.98083 |
| training_iteration                                       6 |
| accuracy                                            0.4123 |
| loss                                               1.53784 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00000 saved a checkpoint for iteration 6 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000005
(func pid=4863) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000005)
(func pid=4879) [7,  6000] loss: 0.356
(func pid=4873) [7,  8000] loss: nan
(func pid=4863) [7,  2000] loss: 1.400

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:34:40. Total running time: 10min 1s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        6           586.981      1.53784       0.4123 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        6           568.56     nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        6           573.937      1.21882       0.5725 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [7, 10000] loss: nan
(func pid=4879) [7,  8000] loss: 0.270
(func pid=4863) [7,  4000] loss: 0.702
(func pid=4873) [7, 12000] loss: nan
(func pid=4879) [7, 10000] loss: 0.215
(func pid=4863) [7,  6000] loss: 0.477
(func pid=4873) [7, 14000] loss: nan
(func pid=4879) [7, 12000] loss: 0.177
(func pid=4863) [7,  8000] loss: 0.358
(func pid=4873) [7, 16000] loss: nan
(func pid=4863) [7, 10000] loss: 0.280 [repeated 2x across cluster]
(func pid=4873) [7, 18000] loss: nan
(func pid=4863) [7, 12000] loss: 0.248 [repeated 2x across cluster]
Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:35:10. Total running time: 10min 31s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        6           586.981      1.53784       0.4123 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        6           568.56     nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        6           573.937      1.21882       0.5725 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [7, 20000] loss: nan
(func pid=4863) [7, 14000] loss: 0.207 [repeated 2x across cluster]

Trial train_cifar_cd391_00002 finished iteration 7 at 2025-06-17 14:35:22. Total running time: 10min 43s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000006 |
| time_this_iter_s                                  70.10403 |
| time_total_s                                     638.66444 |
| training_iteration                                       7 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 7 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000006
(func pid=4873) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000006)
(func pid=4863) [7, 16000] loss: 0.181 [repeated 2x across cluster]

(func pid=4879) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000006)
Trial train_cifar_cd391_00004 finished iteration 7 at 2025-06-17 14:35:26. Total running time: 10min 47s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000006 |
| time_this_iter_s                                   69.2097 |
| time_total_s                                     643.14656 |
| training_iteration                                       7 |
| accuracy                                            0.5772 |
| loss                                               1.20163 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 7 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000006
(func pid=4863) [7, 18000] loss: 0.162
(func pid=4873) [8,  2000] loss: nan
(func pid=4879) [8,  2000] loss: 1.017
(func pid=4873) [8,  4000] loss: nan
(func pid=4879) [8,  4000] loss: 0.516 [repeated 2x across cluster]

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:35:40. Total running time: 11min 1s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        6           586.981      1.53784       0.4123 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        7           638.664    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        7           643.147      1.20163       0.5772 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [8,  6000] loss: nan

Trial train_cifar_cd391_00000 finished iteration 7 at 2025-06-17 14:35:44. Total running time: 11min 5s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000006 |
| time_this_iter_s                                  73.45394 |
| time_total_s                                     660.43477 |
| training_iteration                                       7 |
| accuracy                                            0.3962 |
| loss                                                1.5917 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00000 saved a checkpoint for iteration 7 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000006
(func pid=4863) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000006)
(func pid=4879) [8,  6000] loss: 0.346
(func pid=4873) [8,  8000] loss: nan
(func pid=4879) [8,  8000] loss: 0.254
(func pid=4873) [8, 10000] loss: nan
(func pid=4879) [8, 10000] loss: 0.209 [repeated 2x across cluster]
(func pid=4873) [8, 12000] loss: nan
(func pid=4879) [8, 12000] loss: 0.172 [repeated 2x across cluster]
(func pid=4873) [8, 14000] loss: nan
(func pid=4879) [8, 14000] loss: 0.153 [repeated 2x across cluster]

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:36:10. Total running time: 11min 31s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        7           660.435      1.5917        0.3962 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        7           638.664    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        7           643.147      1.20163       0.5772 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [8, 16000] loss: nan
(func pid=4879) [8, 16000] loss: 0.131 [repeated 2x across cluster]
(func pid=4873) [8, 18000] loss: nan
(func pid=4879) [8, 18000] loss: 0.120 [repeated 2x across cluster]
(func pid=4873) [8, 20000] loss: nan
(func pid=4879) [8, 20000] loss: 0.107 [repeated 2x across cluster]

Trial train_cifar_cd391_00002 finished iteration 8 at 2025-06-17 14:36:31. Total running time: 11min 52s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000007 |
| time_this_iter_s                                   69.4461 |
| time_total_s                                     708.11054 |
| training_iteration                                       8 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 8 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000007
(func pid=4873) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000007)

Trial train_cifar_cd391_00004 finished iteration 8 at 2025-06-17 14:36:36. Total running time: 11min 57s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000007 |
| time_this_iter_s                                  69.32206 |
| time_total_s                                     712.46862 |
| training_iteration                                       8 |
| accuracy                                            0.5864 |
| loss                                               1.17764 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 8 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000007
(func pid=4863) [8, 16000] loss: 0.178 [repeated 2x across cluster]
(func pid=4873) [9,  2000] loss: nan

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:36:40. Total running time: 12min 1s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        7           660.435      1.5917        0.3962 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        8           708.111    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        8           712.469      1.17764       0.5864 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4879) [9,  2000] loss: 1.012
(func pid=4863) [8, 18000] loss: 0.159
(func pid=4873) [9,  4000] loss: nan
(func pid=4879) [9,  4000] loss: 0.498
(func pid=4863) [8, 20000] loss: 0.146
(func pid=4873) [9,  6000] loss: nan
(func pid=4879) [9,  6000] loss: 0.335
(func pid=4873) [9,  8000] loss: nan

Trial train_cifar_cd391_00000 finished iteration 8 at 2025-06-17 14:36:59. Total running time: 12min 20s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000007 |
| time_this_iter_s                                  75.22597 |
| time_total_s                                     735.66074 |
| training_iteration                                       8 |
| accuracy                                            0.4199 |
| loss                                               1.55898 |
+------------------------------------------------------------+
(func pid=4863) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000007) [repeated 2x across cluster]

Trial train_cifar_cd391_00000 saved a checkpoint for iteration 8 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000007
(func pid=4879) [9,  8000] loss: 0.248
(func pid=4873) [9, 10000] loss: nan
(func pid=4879) [9, 10000] loss: 0.200 [repeated 2x across cluster]
(func pid=4873) [9, 12000] loss: nan

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:37:10. Total running time: 12min 31s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        8           735.661      1.55898       0.4199 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        8           708.111    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        8           712.469      1.17764       0.5864 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4879) [9, 12000] loss: 0.163 [repeated 2x across cluster]
(func pid=4873) [9, 14000] loss: nan
(func pid=4863) [9,  6000] loss: 0.452
(func pid=4879) [9, 14000] loss: 0.146
(func pid=4873) [9, 16000] loss: nan
(func pid=4863) [9,  8000] loss: 0.355
(func pid=4879) [9, 16000] loss: 0.132
(func pid=4873) [9, 18000] loss: nan
(func pid=4863) [9, 10000] loss: 0.276
(func pid=4879) [9, 18000] loss: 0.114
(func pid=4873) [9, 20000] loss: nan
(func pid=4879) [9, 20000] loss: 0.100 [repeated 2x across cluster]
Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:37:40. Total running time: 13min 1s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        8           735.661      1.55898       0.4199 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        8           708.111    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        8           712.469      1.17764       0.5864 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+

Trial train_cifar_cd391_00002 finished iteration 9 at 2025-06-17 14:37:43. Total running time: 13min 4s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000008 |
| time_this_iter_s                                   71.6075 |
| time_total_s                                     779.71804 |
| training_iteration                                       9 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 9 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000008
(func pid=4873) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000008)

Trial train_cifar_cd391_00004 finished iteration 9 at 2025-06-17 14:37:47. Total running time: 13min 8s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000008 |
| time_this_iter_s                                  71.27492 |
| time_total_s                                     783.74354 |
| training_iteration                                       9 |
| accuracy                                            0.5906 |
| loss                                               1.20034 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 9 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000008
(func pid=4879) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000008)
(func pid=4863) [9, 16000] loss: 0.172 [repeated 2x across cluster]
(func pid=4873) [10,  2000] loss: nan
(func pid=4863) [9, 18000] loss: 0.157 [repeated 2x across cluster]
(func pid=4873) [10,  4000] loss: nan
(func pid=4863) [9, 20000] loss: 0.140 [repeated 2x across cluster]
(func pid=4873) [10,  6000] loss: nan
(func pid=4873) [10,  8000] loss: nan
(func pid=4879) [10,  6000] loss: 0.314

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:38:10. Total running time: 13min 31s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        8           735.661      1.55898       0.4199 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        9           779.718    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        9           783.744      1.20034       0.5906 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+

Trial train_cifar_cd391_00000 finished iteration 9 at 2025-06-17 14:38:10. Total running time: 13min 31s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000008 |
| time_this_iter_s                                  71.46668 |
| time_total_s                                     807.12742 |
| training_iteration                                       9 |
| accuracy                                            0.4096 |
| loss                                                1.5911 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00000 saved a checkpoint for iteration 9 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000008
(func pid=4879) [10,  8000] loss: 0.246
(func pid=4863) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000008)
(func pid=4873) [10, 10000] loss: nan
(func pid=4879) [10, 10000] loss: 0.196
(func pid=4863) [10,  2000] loss: 1.291
(func pid=4873) [10, 12000] loss: nan
(func pid=4879) [10, 12000] loss: 0.165
(func pid=4863) [10,  4000] loss: 0.667
(func pid=4873) [10, 14000] loss: nan
(func pid=4879) [10, 14000] loss: 0.140
(func pid=4863) [10,  6000] loss: 0.457
(func pid=4873) [10, 16000] loss: nan
(func pid=4879) [10, 16000] loss: 0.124
(func pid=4873) [10, 18000] loss: nan
(func pid=4863) [10,  8000] loss: 0.336

Trial status: 3 RUNNING | 7 TERMINATED
Current time: 2025-06-17 14:38:40. Total running time: 14min 1s
Logical resource usage: 6.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        9           807.127      1.5911        0.4096 |
| train_cifar_cd391_00002   RUNNING        16    256   0.0450584                2        9           779.718    nan             0.0986 |
| train_cifar_cd391_00004   RUNNING        64     16   0.000310926              2        9           783.744      1.20034       0.5906 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4873) [10, 20000] loss: nan
(func pid=4879) [10, 18000] loss: 0.111
(func pid=4863) [10, 10000] loss: 0.272
(func pid=4863) [10, 12000] loss: 0.228 [repeated 2x across cluster]

Trial train_cifar_cd391_00002 finished iteration 10 at 2025-06-17 14:38:52. Total running time: 14min 13s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00002 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000009 |
| time_this_iter_s                                  69.52176 |
| time_total_s                                     849.23979 |
| training_iteration                                      10 |
| accuracy                                            0.0986 |
| loss                                                   nan |
+------------------------------------------------------------+
Trial train_cifar_cd391_00002 saved a checkpoint for iteration 10 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000009

Trial train_cifar_cd391_00002 completed after 10 iterations at 2025-06-17 14:38:52. Total running time: 14min 13s
(func pid=4873) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00002_2_batch_size=2,l1=16,l2=256,lr=0.0451_2025-06-17_14-24-39/checkpoint_000009)
(func pid=4863) [10, 14000] loss: 0.198

Trial train_cifar_cd391_00004 finished iteration 10 at 2025-06-17 14:38:57. Total running time: 14min 18s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00004 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000009 |
| time_this_iter_s                                  70.55227 |
| time_total_s                                     854.29581 |
| training_iteration                                      10 |
| accuracy                                            0.5928 |
| loss                                               1.19272 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00004 saved a checkpoint for iteration 10 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00004_4_batch_size=2,l1=64,l2=16,lr=0.0003_2025-06-17_14-24-39/checkpoint_000009

Trial train_cifar_cd391_00004 completed after 10 iterations at 2025-06-17 14:38:57. Total running time: 14min 18s
(func pid=4863) [10, 16000] loss: 0.173
(func pid=4863) [10, 18000] loss: 0.153

Trial status: 1 RUNNING | 9 TERMINATED
Current time: 2025-06-17 14:39:10. Total running time: 14min 31s
Logical resource usage: 2.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   RUNNING       256      2   0.00105263               2        9           807.127      1.5911        0.4096 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00002   TERMINATED     16    256   0.0450584                2       10           849.24     nan             0.0986 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00004   TERMINATED     64     16   0.000310926              2       10           854.296      1.19272       0.5928 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+
(func pid=4863) [10, 20000] loss: 0.141

Trial train_cifar_cd391_00000 finished iteration 10 at 2025-06-17 14:39:20. Total running time: 14min 41s
+------------------------------------------------------------+
| Trial train_cifar_cd391_00000 result                       |
+------------------------------------------------------------+
| checkpoint_dir_name                      checkpoint_000009 |
| time_this_iter_s                                  69.83214 |
| time_total_s                                     876.95956 |
| training_iteration                                      10 |
| accuracy                                            0.4194 |
| loss                                               1.62979 |
+------------------------------------------------------------+
Trial train_cifar_cd391_00000 saved a checkpoint for iteration 10 at: (local)/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000009

Trial train_cifar_cd391_00000 completed after 10 iterations at 2025-06-17 14:39:20. Total running time: 14min 41s
(func pid=4863) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/var/lib/ci-user/ray_results/train_cifar_2025-06-17_14-24-38/train_cifar_cd391_00000_0_batch_size=2,l1=256,l2=2,lr=0.0011_2025-06-17_14-24-39/checkpoint_000009) [repeated 2x across cluster]

Trial status: 10 TERMINATED
Current time: 2025-06-17 14:39:20. Total running time: 14min 41s
Logical resource usage: 2.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:A10G)
+--------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                status         l1     l2            lr     batch_size     iter     total time (s)        loss     accuracy |
+--------------------------------------------------------------------------------------------------------------------------------------+
| train_cifar_cd391_00000   TERMINATED    256      2   0.00105263               2       10           876.96       1.62979       0.4194 |
| train_cifar_cd391_00001   TERMINATED     64      2   0.0189753               16       10           215.319      1.87434       0.2731 |
| train_cifar_cd391_00002   TERMINATED     16    256   0.0450584                2       10           849.24     nan             0.0986 |
| train_cifar_cd391_00003   TERMINATED      8     16   0.00920872               2        1           124.226      2.3138        0.0994 |
| train_cifar_cd391_00004   TERMINATED     64     16   0.000310926              2       10           854.296      1.19272       0.5928 |
| train_cifar_cd391_00005   TERMINATED      4      1   0.00322626               4        8           448.629      1.88082       0.2387 |
| train_cifar_cd391_00006   TERMINATED      1     16   0.000669639              4        8           447.661      1.7753        0.2635 |
| train_cifar_cd391_00007   TERMINATED     64      1   0.00143856               2        1           126.116      2.30463       0.0995 |
| train_cifar_cd391_00008   TERMINATED     32     64   0.00411186               8       10           302.876      1.28149       0.5625 |
| train_cifar_cd391_00009   TERMINATED    256     64   0.000399319              8        1            43.2076     1.95413       0.2744 |
+--------------------------------------------------------------------------------------------------------------------------------------+

Best trial config: {'l1': 64, 'l2': 16, 'lr': 0.00031092580906260324, 'batch_size': 2}
Best trial final validation loss: 1.1927188325223512
Best trial final validation accuracy: 0.5928
Best trial test set accuracy: 0.5935
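The best trial reported above is the one with the lowest last-reported validation loss across all ten trials. The plain-Python sketch below reproduces that selection from the final status table. The values are rounded as printed, and the `TrialResult` type and `best_trial` helper are hypothetical names used for illustration, not Ray Tune APIs. The helper skips trials whose loss is NaN, such as trial 00002. This matters because Python's `min` compares with `<`, and every comparison against NaN is false, so an unfiltered NaN that happened to come first in the list would be returned as the "best" value.

```python
import math
from typing import NamedTuple


class TrialResult(NamedTuple):
    trial: str
    l1: int
    l2: int
    lr: float
    batch_size: int
    loss: float      # last reported validation loss
    accuracy: float  # last reported validation accuracy


# Rows copied from the final status table above (values rounded as printed).
ROWS = [
    ("00000", 256, 2, 0.00105263, 2, 1.62979, 0.4194),
    ("00001", 64, 2, 0.0189753, 16, 1.87434, 0.2731),
    ("00002", 16, 256, 0.0450584, 2, float("nan"), 0.0986),
    ("00003", 8, 16, 0.00920872, 2, 2.3138, 0.0994),
    ("00004", 64, 16, 0.000310926, 2, 1.19272, 0.5928),
    ("00005", 4, 1, 0.00322626, 4, 1.88082, 0.2387),
    ("00006", 1, 16, 0.000669639, 4, 1.7753, 0.2635),
    ("00007", 64, 1, 0.00143856, 2, 2.30463, 0.0995),
    ("00008", 32, 64, 0.00411186, 8, 1.28149, 0.5625),
    ("00009", 256, 64, 0.000399319, 8, 1.95413, 0.2744),
]
RESULTS = [TrialResult(*row) for row in ROWS]


def best_trial(results):
    """Return the result with the lowest last-reported loss, ignoring NaN losses."""
    finite = [r for r in results if not math.isnan(r.loss)]
    if not finite:
        raise ValueError("every trial reported a NaN loss")
    return min(finite, key=lambda r: r.loss)


best = best_trial(RESULTS)
print(best.trial, best.loss, best.accuracy)  # 00004 1.19272 0.5928
```

This picks trial 00004 (`l1=64`, `l2=16`, `batch_size=2`), matching the best trial config printed by the script up to the rounding of `lr`.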

If you run the code, an example output could look like this:

Number of trials: 10/10 (10 TERMINATED)
+-----+--------------+------+------+-------------+--------+---------+------------+
| ... |   batch_size |   l1 |   l2 |          lr |   iter |    loss |   accuracy |
|-----+--------------+------+------+-------------+--------+---------+------------|
| ... |            2 |    1 |  256 | 0.000668163 |      1 | 2.31479 |     0.0977 |
| ... |            4 |   64 |    8 | 0.0331514   |      1 | 2.31605 |     0.0983 |
| ... |            4 |    2 |    1 | 0.000150295 |      1 | 2.30755 |     0.1023 |
| ... |           16 |   32 |   32 | 0.0128248   |     10 | 1.66912 |     0.4391 |
| ... |            4 |    8 |  128 | 0.00464561  |      2 | 1.7316  |     0.3463 |
| ... |            8 |  256 |    8 | 0.00031556  |      1 | 2.19409 |     0.1736 |
| ... |            4 |   16 |  256 | 0.00574329  |      2 | 1.85679 |     0.3368 |
| ... |            8 |    2 |    2 | 0.00325652  |      1 | 2.30272 |     0.0984 |
| ... |            2 |    2 |    2 | 0.000342987 |      2 | 1.76044 |     0.292  |
| ... |            4 |   64 |   32 | 0.003734    |      8 | 1.53101 |     0.4761 |
+-----+--------------+------+------+-------------+--------+---------+------------+

Best trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}
Best trial final validation loss: 1.5310075663924216
Best trial final validation accuracy: 0.4761
Best trial test set accuracy: 0.4737

Most trials were stopped early to avoid wasting resources. The best-performing trial achieved a validation accuracy of about 47%, and its test set accuracy (47.37%) confirms this result.
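The early stopping comes from the `ASHAScheduler` imported at the top of the tutorial, which implements asynchronous successive halving. The sketch below is a much-simplified, single-bracket illustration of that idea. The rung schedule, the floor-based cutoff, the NaN handling, and the `asha_rungs`/`should_stop` helpers are assumptions for illustration and are not Ray Tune's actual implementation. The `grace_period`, `reduction_factor`, and loss values in the usage lines are hypothetical.

```python
import math


def asha_rungs(grace_period, reduction_factor, max_t):
    """Return the iterations (rungs) at which a trial is compared with its peers."""
    rungs = []
    t = grace_period
    while t <= max_t:
        rungs.append(t)
        t *= reduction_factor
    return rungs


def should_stop(loss, rung_losses, reduction_factor):
    """Record `loss` at one rung and decide whether the trial stops there.

    Simplified rule: the trial continues only if its loss ranks within the
    best 1/reduction_factor fraction of all losses recorded at this rung so
    far. Because trials arrive asynchronously, early arrivals are judged
    against fewer peers than late ones.
    """
    if math.isnan(loss):
        return True  # design choice of this sketch: never promote a diverged trial
    rung_losses.append(loss)
    ranked = sorted(rung_losses)
    keep = max(1, len(ranked) // reduction_factor)  # how many may continue
    return loss > ranked[keep - 1]


# Hypothetical settings and losses, for illustration only.
print(asha_rungs(grace_period=1, reduction_factor=2, max_t=10))  # [1, 2, 4, 8]
rung_1 = []  # losses recorded at iteration 1, in arrival order
for trial, loss in [("a", 2.10), ("b", 2.31), ("c", 1.95), ("d", 2.30)]:
    print(trial, "stop" if should_stop(loss, rung_1, reduction_factor=2) else "continue")
```

Here trials `a` and `c` continue while `b` and `d` stop. Note that in the run above, trial 00002 reported a NaN loss yet ran all 10 iterations, so the scheduler used there evidently did not stop NaN trials the way this sketch does.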

So that’s it! You can now tune the parameters of your PyTorch models.

Total running time of the script: (14 minutes 55.556 seconds)

