fennol.training.optimizers

from typing import (
    Callable,
    Optional,
    Dict,
    List,
    Tuple,
    Union,
    Any,
    NamedTuple,
    Sequence,
)
import optax
import jax
import jax.numpy as jnp
import numpy as np
import operator
from flax import traverse_util
import json
import re

from optax._src import base
from optax._src import wrappers
from optax import tree_utils as otu


class AddWeightDiffState(NamedTuple):
    ref_weights: Any


def add_weights_difference(
    weight_decay: Union[float, jax.Array] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
    """Weight decay toward initial weights."""

    def init_fn(params):
        return AddWeightDiffState(ref_weights=params)

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)
        updates = jax.tree_util.tree_map(
            lambda g, p, pref: g + weight_decay * (p - pref),
            updates,
            params,
            state.ref_weights,
        )
        return updates, state

    # If mask is not `None`, apply mask to the gradient transformation.
    # E.g. it is common to skip weight decay on bias units and batch stats.
    if mask is not None:
        return wrappers.masked(base.GradientTransformation(init_fn, update_fn), mask)
    return base.GradientTransformation(init_fn, update_fn)


def add_grokfast(
    alpha: float = 0.9,
    l: float = 1.0,
) -> base.GradientTransformation:
    """Grokfast: amplify slow gradients by exponential moving average."""

    ema: base.GradientTransformation = optax.ema(decay=alpha, debias=False)

    def init_fn(params):
        return ema.init(params)

    def update_fn(updates, state, params=None):
        dupdates, state = ema.update(updates, state, params)
        # updates = updates + l * dupdates
        updates = jax.tree_util.tree_map(lambda g, d: g + l * d, updates, dupdates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)


class PROFITState(NamedTuple):
    ref_weights: Any
    istep: int
    main_opt_state: Any
    internal_opt_state: Any


def profit(
    learning_rate: base.ScalarOrSchedule,
    nsteps_ref: int = 1,
    main_opt: str = "adam",
    main_opt_params: Dict[str, Any] = {},
    internal_opt: str = "sgd",
    internal_opt_params: Dict[str, Any] = {},
    **kwargs,
) -> base.GradientTransformation:
    """PROFIT optimizer for fine-tuning. https://arxiv.org/pdf/2412.01930"""

    main_opt_params = {"learning_rate": learning_rate, **main_opt_params}
    main_opt = eval(
        main_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**main_opt_params)

    internal_opt_params = {"learning_rate": 0.1, **internal_opt_params}
    internal_opt = eval(
        internal_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**internal_opt_params)

    def init_fn(params):
        return PROFITState(
            ref_weights=params,
            istep=0,
            main_opt_state=main_opt.init(params),
            internal_opt_state=internal_opt.init(params),
        )

    def update_main(gradients, main_opt_state, internal_opt_state, params, params_ref):
        # displacement accumulated by the internal optimizer since the last
        # reference snapshot
        delta = jax.tree_util.tree_map(lambda p, pref: p - pref, params, params_ref)
        dot = jax.tree.reduce(
            operator.add,
            jax.tree_util.tree_map(lambda g, d: (g * d).sum(), gradients, delta),
        )
        delta2 = jax.tree.reduce(
            operator.add, jax.tree_util.tree_map(lambda d: (d**2).sum(), delta)
        )
        proj = dot / (delta2 + 1.0e-6)

        # if the gradients oppose the displacement, replace them by their
        # projection onto it
        gradients = jax.lax.cond(
            dot >= 0,
            lambda g, d: g,
            lambda g, d: jax.tree_util.tree_map(lambda x: proj * x, d),
            gradients,
            delta,
        )
        updates, main_opt_state = main_opt.update(gradients, main_opt_state, params)
        # rewind the internal displacement so the main step starts from the
        # reference weights
        updates = jax.tree_util.tree_map(lambda g, d: g - d, updates, delta)
        return updates, main_opt_state, internal_opt_state

    def update_internal(
        gradients, main_opt_state, internal_opt_state, params, params_ref
    ):
        updates, internal_opt_state = internal_opt.update(
            gradients, internal_opt_state, params
        )
        return updates, main_opt_state, internal_opt_state

    def update_fn(gradients, state, params):
        istep = state.istep % (nsteps_ref + 1)

        # snapshot the reference weights at the start of each cycle
        params_ref = jax.lax.cond(
            istep == 0, lambda a, b: a, lambda a, b: b, params, state.ref_weights
        )

        updates, main_opt_state, internal_opt_state = jax.lax.cond(
            istep == nsteps_ref,
            update_main,
            update_internal,
            gradients,
            state.main_opt_state,
            state.internal_opt_state,
            params,
            params_ref,
        )

        new_state = PROFITState(
            ref_weights=params_ref,
            istep=state.istep + 1,
            main_opt_state=main_opt_state,
            internal_opt_state=internal_opt_state,
        )
        return updates, new_state

    return base.GradientTransformation(init_fn, update_fn)


class MultiEmaState(NamedTuple):
    """Holds exponential moving averages of past updates."""

    count: int
    ema: Sequence[base.Params]


def multi_ema(
    decays: Sequence[float],
    debias: bool = True,
    power: Union[bool, Sequence[bool]] = False,
) -> base.GradientTransformation:
    """Compute multiple power moving averages of past updates."""

    if len(decays) == 0:

        def init_fn(params):
            return base.EmptyState()

        def update_fn(updates, state, params=None):
            return [updates], state

        return base.GradientTransformation(init_fn, update_fn)

    if isinstance(power, bool):
        power = [power] * len(decays)
    assert len(power) == len(decays), "power and decays must have the same length"

    gammas = []
    for decay, p in zip(decays, power):
        if not p:
            gammas.append(None)
            continue
        # solve for the power-EMA exponent gamma matching the requested
        # relative width
        t = decay**-2
        gamma = np.roots([1, 7, 16 - t, 12 - t]).real.max()
        assert gamma > 0, f"Invalid gamma for decay {decay}: {gamma}"
        gammas.append(gamma)

    def init_fn(params):
        return MultiEmaState(
            count=0,
            ema=[otu.tree_zeros_like(params)] * len(decays),
        )

    def update_fn(params, state, other=None):
        count_inc = state.count + 1
        updates = []
        state_ema = []
        for decay, gamma, ema in zip(decays, gammas, state.ema):
            if gamma is not None:
                # power EMA: the effective decay grows with the step count
                decay = (1.0 - 1.0 / count_inc) ** (gamma + 1)
            update = new_ema = otu.tree_update_moment(params, ema, decay, order=1)
            if debias and gamma is None:
                update = otu.tree_bias_correction(update, decay, count_inc)
            updates.append(update)
            state_ema.append(new_ema)
        return updates, MultiEmaState(count=count_inc, ema=state_ema)

    return base.GradientTransformation(init_fn, update_fn)


def get_optimizer(
    training_parameters: Dict[str, Any], variables: Dict, initial_lr: float
) -> Tuple[optax.GradientTransformation, int]:
    """
    Returns an optax.GradientTransformation object that can be used to optimize
    the model parameters, together with the index of the learning-rate scaling
    step in the optimizer chain.

    Args:
    - training_parameters: A dictionary containing the training parameters.
    - variables: A pytree containing the model parameters.
    - initial_lr: The initial learning rate.

    Returns:
    - An optax.GradientTransformation object that can be used to optimize the model parameters.
    - The index, counted from the end of the optimizer chain, of the learning-rate scaling step.
    """

    default_status = str(training_parameters.get("default_status", "trainable")).lower()
    assert default_status in [
        "trainable",
        "frozen",
    ], f"Default status must be 'trainable' or 'frozen', got {default_status}"

    # find frozen and trainable parameters
    frozen = training_parameters.get("frozen", [])
    trainable = training_parameters.get("trainable", [])

    def training_status(full_path, v):
        # drop the leading collection name (e.g. "params") before matching
        full_path = "/".join(full_path[1:]).lower()
        status = (default_status, "")
        for path in frozen:
            if re.match(path.lower(), full_path):
                status = ("frozen", path)
        for path in trainable:
            if re.match(path.lower(), full_path):
                status = ("trainable", path)
        return status[0]

    params_partition = traverse_util.path_aware_map(training_status, variables)
    if len(frozen) > 0 or len(trainable) > 0:
        print("params partition:")
        print(json.dumps(params_partition, indent=2, sort_keys=False))

    ## Gradient preprocessing
    grad_processing = []

    # zero nans
    zero_nans = training_parameters.get("zero_nans", False)
    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # gradient clipping
    clip_threshold = training_parameters.get("gradient_clipping", -1.0)
    clip_range = str(training_parameters.get("clipping_range", "global")).lower()
    if clip_threshold > 0.0:
        print("gradient norm clipping threshold:", clip_threshold)
        if clip_range == "local":
            _clip = optax.clip
        elif clip_range == "global":
            _clip = optax.clip_by_global_norm
        elif clip_range == "block":
            _clip = optax.clip_by_block_rms
        else:
            raise ValueError(
                f"Invalid clipping_range: {clip_range}. Should be 'local', 'global' or 'block'."
            )

        grad_processing.append(_clip(clip_threshold))

    use_grokfast = training_parameters.get("use_grokfast", False)
    if use_grokfast:
        print("Using Grokfast")
        alpha_grokfast = training_parameters.get("alpha_grokfast", 0.9)
        l_grokfast = training_parameters.get("l_grokfast", 1.0)
        grad_processing.append(add_grokfast(alpha=alpha_grokfast, l=l_grokfast))

    # OPTIMIZER
    optimizer_name = training_parameters.get("optimizer", "adabelief")
    optimizer = eval(
        optimizer_name,
        {"__builtins__": None},
        {
            **optax.__dict__,
            "profit": profit,
        },
    )
    print("Optimizer:", optimizer_name)
    optimizer_configuration = training_parameters.get("optimizer_config", {})
    optimizer_configuration["learning_rate"] = 1.0
    grad_processing.append(optimizer(**optimizer_configuration))

    # weight decay
    weight_decay = training_parameters.get("weight_decay", 0.0)
    assert weight_decay >= 0.0, "Weight decay must be non-negative"
    decay_targets = training_parameters.get("decay_targets", [""])

    def decay_status(full_path, v):
        full_path = "/".join(full_path).lower()
        status = False
        for path in decay_targets:
            if re.match(r"^params/" + path.lower(), full_path):
                status = True
        return status

    decay_mask = traverse_util.path_aware_map(decay_status, variables)
    if weight_decay > 0.0:
        print("weight decay:", weight_decay)
        print(json.dumps(decay_mask, indent=2, sort_keys=False))
        # negative coefficient: this runs after the optimizer, where updates
        # already point in the descent direction
        grad_processing.append(
            optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask)
        )

    regularize_init_weight = training_parameters.get("regularize_init_weights", 0.0)
    if regularize_init_weight > 0.0:
        print(
            "Regularizing toward initial weights with L2 norm:", regularize_init_weight
        )
        if weight_decay <= 0.0:
            print(json.dumps(decay_mask, indent=2, sort_keys=False))

        grad_processing.append(
            add_weights_difference(
                weight_decay=-regularize_init_weight, mask=decay_mask
            )
        )

    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # learning rate
    grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr))
    ilr = -1

    # adaptive gradient clipping
    clip_threshold = training_parameters.get("adaptive_gradient_clipping", -1.0)
    if clip_threshold > 0.0:
        print("Adaptive gradient clipping threshold:", clip_threshold)
        grad_processing.append(optax.adaptive_grad_clip(clip_threshold))
        ilr -= 1

    ## define optimizer chain
    optimizer_ = optax.chain(
        *grad_processing,
    )
    partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()}
    return optax.multi_transform(partition_optimizer, params_partition), ilr


def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters):
    lr = training_parameters.get("lr", 1.0e-3)
    init_lr = training_parameters.get("init_lr", lr / 25)
    final_lr = training_parameters.get("final_lr", lr / 10000)

    #### LEARNING RATE SCHEDULER ####
    schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower()
    schedule_type = training_parameters.get("scheduler", schedule_type).lower()
    schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot")

    adaptive_scheduler = False
    print("Schedule type:", schedule_type)
    if schedule_type == "cosine_onecycle":
        transition_epochs = training_parameters.get("onecycle_epochs", max_epochs)
        peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs)
        linear_warmup = training_parameters.get("linear_warmup", False)

        if linear_warmup:
            schedule_ = optax.warmup_cosine_decay_schedule(
                init_value=init_lr,
                peak_value=lr,
                warmup_steps=peak_epoch * nbatch_per_epoch,
                decay_steps=transition_epochs * nbatch_per_epoch,
                end_value=final_lr,
            )
        else:
            schedule_ = optax.cosine_onecycle_schedule(
                peak_value=lr,
                div_factor=lr / init_lr,
                final_div_factor=init_lr / final_lr,
                transition_steps=transition_epochs * nbatch_per_epoch,
                pct_start=peak_epoch / transition_epochs,
            )
        sch_state = {"count": 0, "best": np.inf, "lr": init_lr}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "piecewise_interpolate":
        schedule_params = training_parameters.get("scheduler_parameters", {})
        schedule_ = optax.piecewise_interpolate_schedule(
            **{"init_value": lr, "interpolate_type": "linear", **schedule_params}
        )
        sch_state = {"count": 0, "best": np.inf, "lr": schedule_(0)}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "constant":
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "cosine":
        assert (
            "peak_epoch" in training_parameters
        ), "Cosine schedule requires 'peak_epoch' parameter"
        period = training_parameters["peak_epoch"] * nbatch_per_epoch
        peak_lr = lr
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            istep = state["count"]
            # cosine oscillation between init_lr and peak_lr with half-period
            # 'peak_epoch' epochs
            g = 0.5 * (1 + jnp.cos(jnp.pi * istep / period))
            lr = peak_lr + (init_lr - peak_lr) * g
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "reduce_on_plateau":
        patience = training_parameters.get("patience", 10)
        factor = training_parameters.get("lr_factor", 0.5)
        patience_thr = training_parameters.get("patience_thr", 0.0)
        sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience}
        adaptive_scheduler = True

        def schedule(state, rmse=None):
            new_state = {**state}
            if rmse is None:
                new_state["count"] += 1
                return state["lr"], new_state
            if rmse <= state["best"] * (1.0 + patience_thr):
                if rmse < state["best"]:
                    new_state["best"] = rmse
                new_state["patience"] = 0
            else:
                new_state["patience"] += 1
                if new_state["patience"] >= patience:
                    new_state["lr"] = state["lr"] * factor
                    new_state["patience"] = 0
                    print("Reducing learning rate to", new_state["lr"])
            return new_state["lr"], new_state

    else:
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

    stochastic_scheduler = training_parameters.get("stochastic_scheduler", False)
    if stochastic_scheduler:
        schedule_ = schedule
        # seed the scheduler RNG (the "schedule_seed" parameter name is an
        # assumption; the original code referenced an undefined rng_key here)
        rng_key = jax.random.PRNGKey(training_parameters.get("schedule_seed", 0))
        rng_key, scheduler_key = jax.random.split(rng_key)
        sch_state["rng_key"] = scheduler_key
        sch_state["lr_max"] = lr
        sch_state["lr_min"] = final_lr

        def schedule(state, rmse=None):
            new_state = {**state, "lr": state["lr_max"]}
            if rmse is None:
                lr_max, new_state = schedule_(new_state, rmse=rmse)
                lr_min = new_state["lr_min"]
                new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"])
                lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey)
                new_state["lr"] = lr
                new_state["lr_max"] = lr_max

            return new_state["lr"], new_state

    return schedule, sch_state, schedule_metrics, adaptive_scheduler

class AddWeightDiffState(typing.NamedTuple):

AddWeightDiffState(ref_weights: Any)

ref_weights: Any

def add_weights_difference(weight_decay: Union[float, jax.Array] = 0.0, mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None) -> optax._src.base.GradientTransformation:

Weight decay toward initial weights.
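
A minimal usage sketch (the parameter pytree and coefficients are illustrative). Chained before plain SGD, a positive coefficient pulls the parameters back toward the weights captured at init; get_optimizer instead inserts this transformation after the main optimizer with a negated coefficient, which has the same net effect.

import jax.numpy as jnp
import optax

from fennol.training.optimizers import add_weights_difference

params = {"w": jnp.ones(3), "b": jnp.zeros(3)}
tx = optax.chain(add_weights_difference(weight_decay=1.0e-4), optax.sgd(1.0e-2))
opt_state = tx.init(params)  # captures the reference weights

grads = {"w": jnp.full(3, 0.1), "b": jnp.full(3, 0.1)}
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)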

def add_grokfast(alpha: float = 0.9, l: float = 1.0) -> optax._src.base.GradientTransformation:

Grokfast: amplify slow gradients by exponential moving average.
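
A sketch of the typical wiring (hyperparameter values are illustrative): the filter sits ahead of the main optimizer, so the optimizer consumes gradients augmented with their slow-moving EMA component, g + l * ema(g).

import jax.numpy as jnp
import optax

from fennol.training.optimizers import add_grokfast

params = {"w": jnp.ones(4)}
tx = optax.chain(add_grokfast(alpha=0.98, l=2.0), optax.adam(1.0e-3))
opt_state = tx.init(params)

grads = {"w": jnp.full(4, 0.5)}
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)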

class PROFITState(typing.NamedTuple):

PROFITState(ref_weights: Any, istep: int, main_opt_state: Any, internal_opt_state: Any)

ref_weights: Any

istep: int

main_opt_state: Any

internal_opt_state: Any

def profit(learning_rate: base.ScalarOrSchedule, nsteps_ref: int = 1, main_opt: str = 'adam', main_opt_params: Dict[str, Any] = {}, internal_opt: str = 'sgd', internal_opt_params: Dict[str, Any] = {}, **kwargs) -> optax._src.base.GradientTransformation:

PROFIT optimizer for fine-tuning. https://arxiv.org/pdf/2412.01930
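
A toy fine-tuning loop (the loss and hyperparameters are illustrative). With nsteps_ref=1 the optimizer alternates one internal SGD step with one main Adam step, snapshotting the reference weights at the start of each cycle as implemented above.

import jax
import jax.numpy as jnp
import optax

from fennol.training.optimizers import profit

params = {"w": jnp.ones(3)}
tx = profit(learning_rate=1.0e-3, nsteps_ref=1, main_opt="adam", internal_opt="sgd")
opt_state = tx.init(params)

def loss_fn(p):
    return jnp.sum((p["w"] - 2.0) ** 2)

for _ in range(4):  # cycle: internal, main, internal, main
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)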

class MultiEmaState(typing.NamedTuple):

Holds exponential moving averages of past updates.

MultiEmaState(count: int, ema: Sequence[base.Params])

count: int

ema: Sequence[base.Params]

def multi_ema(decays: Sequence[float], debias: bool = True, power: Union[bool, Sequence[bool]] = False) -> optax._src.base.GradientTransformation:

Compute multiple power moving averages of past updates.
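
A sketch tracking two averages of a parameter pytree (values are illustrative). Note that update() takes the quantity to average and returns one pytree per decay. For a power average, the corresponding decays entry is fed through the gamma solve above, so it should be a small relative width (e.g. 0.15); values close to 1 make the gamma assertion fail.

import jax.numpy as jnp

from fennol.training.optimizers import multi_ema

params = {"w": jnp.ones(2)}

# one classical debiased EMA (decay 0.99) and one power EMA (relative width 0.15)
tx = multi_ema(decays=[0.99, 0.15], power=[False, True])
ema_state = tx.init(params)

# called once per training step with the current parameters
averaged, ema_state = tx.update(params, ema_state)
classical_ema, power_ema = averaged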

def get_optimizer(training_parameters: Dict[str, Any], variables: Dict, initial_lr: float) -> Tuple[optax.GradientTransformation, int]:

Returns an optax.GradientTransformation object that can be used to optimize the model parameters, together with the index of the learning-rate scaling step in the optimizer chain.

Args:

  • training_parameters: A dictionary containing the training parameters.
  • variables: A pytree containing the model parameters.
  • initial_lr: The initial learning rate.

Returns:

  • An optax.GradientTransformation object that can be used to optimize the model parameters.
  • The index, counted from the end of the optimizer chain, of the learning-rate scaling step.
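
A sketch with a hypothetical flax-style variables pytree (module names, parameter values, and configuration values are illustrative; the keys match those read above):

import jax.numpy as jnp

from fennol.training.optimizers import get_optimizer

variables = {"params": {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)}}}
training_parameters = {
    "optimizer": "adam",
    "gradient_clipping": 1.0,    # global-norm clipping by default
    "weight_decay": 1.0e-4,
    "decay_targets": ["dense"],  # decay only params/dense/*
}
tx, ilr = get_optimizer(training_parameters, variables, initial_lr=1.0e-3)
opt_state = tx.init(variables)
# ilr locates the inject_hyperparams(scale) step from the end of the chain,
# so the learning rate can be modified in the optimizer state between steps.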

def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters):

Builds the learning-rate schedule. Returns the schedule function, its initial state, the name of the monitored metric, and a flag indicating whether the scheduler is adaptive (i.e. reacts to the monitored metric).
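
A sketch of driving the returned schedule (parameter values are illustrative). Calls without rmse advance the per-batch step counter; calls with the monitored metric let adaptive schedulers such as reduce_on_plateau react.

from fennol.training.optimizers import get_lr_schedule

training_parameters = {"scheduler": "reduce_on_plateau", "lr": 1.0e-3, "patience": 5}
schedule, sch_state, metric_name, adaptive = get_lr_schedule(
    max_epochs=100, nbatch_per_epoch=50, training_parameters=training_parameters
)

lr, sch_state = schedule(sch_state)              # per-batch call: advances the counter
lr, sch_state = schedule(sch_state, rmse=0.123)  # end-of-epoch call with the metric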