Skip to content

Train ArchesWeather and ArchesWeatherGen

This section provides the full training pipeline to retrain ArchesWeatherGen from scratch. It assumes that you have already installed the geoarches package and downloaded the necessary data.

Before you start

You can define the following aliases in your shell config (e.g. .bashrc or .zshrc) to simplify the commands:

alias train=python -m geoarches.main_hydra ++log=True
alias test=python -m geoarches.main_hydra ++mode=test
alias strain=python -m geoarches.submit

Step 1. Training ArchesWeather

First, we train four deterministic versions of ArchesWeather on ERA5 data:

for i in {0..3}; do
    train dataloader=era5 module=archesweather ++name=archesweather-m-seed$i
done

Success

Each model will save its configuration to modelstore/archesweather-m-seed{i}/config.yaml and checkpoints under modelstore/archesweather-m-seed$i/checkpoints/.

Info

In the released checkpoints, two models include skip connections, but that should not really matter.

Note

ArchesWeatherGen does not require multistep fine-tuning.

Step 2. Compute residuals on the ERA5 dataset

Since ArchesWeatherGen models residuals, we can pre-compute them on the full dataset to speed up training:

python -m geoarches.inference.encode_dataset \
    --uids archesweather-m-seed0,archesweather-m-seed1,archesweather-m-seed2,archesweather-m-seed3 \
    --output-path data/outputs/deterministic/archesweather-m4/

Step 3. Training ArchesWeatherGen

Once residuals are computed, we train the flow matching model:

M4ARGS="++dataloader.dataset.pred_path=data/outputs/deterministic/archesweather-m4 \
++module.module.load_deterministic_model=[archesweather-m-seed0,archesweather-m-seed1,archesweather-m-seed2,archesweather-m-seed3] "

train module=archesweathergen dataloader=era5pred \
    ++limit_val_batches=10 \
    ++max_steps=200000 \
    ++name=archesweathergen-s \
    $M4ARGS \
    ++seed=0

Step 4. Fine-tuning ArchesWeatherGen

In the paper, we fine-tune the model on 2019 data to overcome overfitting of the deterministic models. See the paper for more details.

train module=archesweathergen dataloader=era5pred \
    ++limit_val_batches=10 ++max_steps=60000 \
    "++name=archesweathergen-s-ft" \
    $M4ARGS \
    "++load_ckpt=modelstore/archesweathergen-s" \
    "++ckpt_filename_match=200000" \ # for loading the checkpoint at 200k steps
    "++dataloader.dataset.domain=val" \ # fine-tune on validation
    "++module.module.lr=1e-4" \
    "++module.module.num_warmup_steps=500" \
    "++module.module.betas=[0.95, 0.99]" \
    "++save_step_frequency=20000"

Step 5. Evaluation

Finally, we can evaluate the saved model:

multistep=10

test ++name=archesweathergen-s-ft \
    ++limit_test_batches=0.1 \ # optional, for running on fewer members
    ++dataloader.test_args.multistep=$multistep \
    ++module.inference.save_test_outputs=True \ # can be set to False to not save forecasts \
    ++module.inference.rollout_iterations=$multistep \
    ++module.inference.num_steps=25 \ # number of diffusion steps
    ++module.inference.num_members=50 \
    ++module.inference.scale_input_noise=1.05  # to use noise scaling