fennol.training.optimizers
from typing import (
    Callable,
    Optional,
    Dict,
    List,
    Tuple,
    Union,
    Any,
    NamedTuple,
    Sequence,
)
import optax
import jax
import jax.numpy as jnp
import numpy as np
import operator
from flax import traverse_util
import json
import re

from optax._src import base
from optax._src import wrappers
from optax import tree_utils as otu


class AddWeightDiffState(NamedTuple):
    # Parameters captured at init time; updates are decayed toward these.
    ref_weights: Any


def add_weights_difference(
    weight_decay: Union[float, jax.Array] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
    """Weight decay toward the initial weights (instead of toward zero).

    Adds ``weight_decay * (p - p_ref)`` to each update, where ``p_ref`` are the
    parameters seen at ``init``.

    Args:
        weight_decay: scalar factor multiplying the parameter difference.
        mask: optional pytree (or callable producing one) selecting which
            leaves the transformation applies to.

    Returns:
        An ``optax.GradientTransformation``.

    Raises:
        ValueError: at update time, if ``params`` is not provided.
    """

    def init_fn(params):
        return AddWeightDiffState(ref_weights=params)

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)
        updates = jax.tree_util.tree_map(
            lambda g, p, pref: g + weight_decay * (p - pref),
            updates,
            params,
            state.ref_weights,
        )
        return updates, state

    # If mask is not `None`, apply mask to the gradient transformation.
    # E.g. it is common to skip weight decay on bias units and batch stats.
    if mask is not None:
        return wrappers.masked(base.GradientTransformation(init_fn, update_fn), mask)
    return base.GradientTransformation(init_fn, update_fn)


def add_grokfast(
    alpha: float = 0.9,
    l: float = 1.0,
) -> base.GradientTransformation:
    """Grokfast: amplify slow gradients by exponential moving average.

    Args:
        alpha: EMA decay of the slow-gradient average.
        l: amplification factor for the EMA term added to the raw updates.
    """

    ema: base.GradientTransformation = optax.ema(decay=alpha, debias=False)

    def init_fn(params):
        return ema.init(params)

    def update_fn(updates, state, params=None):
        dupdates, state = ema.update(updates, state, params)
        # updates = updates + l*dupdates
        updates = jax.tree_util.tree_map(lambda g, d: g + l * d, updates, dupdates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)


class PROFITState(NamedTuple):
    # Reference weights refreshed at the start of each (nsteps_ref + 1) cycle.
    ref_weights: Any
    # Global step counter (drives the internal/main alternation).
    istep: int
    main_opt_state: Any
    internal_opt_state: Any


def profit(
    learning_rate: base.ScalarOrSchedule,
    nsteps_ref: int = 1,
    main_opt: str = "adam",
    main_opt_params: Optional[Dict[str, Any]] = None,
    internal_opt: str = "sgd",
    internal_opt_params: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> base.GradientTransformation:
    """PROFIT optimizer for fine-tuning https://arxiv.org/pdf/2412.01930

    Alternates ``nsteps_ref`` internal-optimizer steps with one main-optimizer
    step whose gradient is projected against the displacement from the
    reference weights.

    Args:
        learning_rate: learning rate (or schedule) for the main optimizer.
        nsteps_ref: number of internal steps per main step.
        main_opt / main_opt_params: optax optimizer name and kwargs.
        internal_opt / internal_opt_params: optax optimizer name and kwargs
            (defaults to SGD with learning rate 0.1).
    """

    # NOTE(review): `eval` is restricted to optax's namespace with builtins
    # disabled; optimizer names are expected to come from trusted training
    # configuration, never from untrusted input.
    main_opt_params = {"learning_rate": learning_rate, **(main_opt_params or {})}
    main_opt = eval(
        main_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**main_opt_params)

    internal_opt_params = {"learning_rate": 0.1, **(internal_opt_params or {})}
    internal_opt = eval(
        internal_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**internal_opt_params)

    def init_fn(params):
        return PROFITState(
            ref_weights=params,
            istep=0,
            main_opt_state=main_opt.init(params),
            internal_opt_state=internal_opt.init(params),
        )

    def update_main(gradients, main_opt_state, internal_opt_state, params, params_ref):
        # Displacement accumulated by the internal optimizer since the last
        # reference point.
        delta = jax.tree_util.tree_map(lambda p, pref: p - pref, params, params_ref)
        dot = jax.tree.reduce(
            operator.add,
            jax.tree_util.tree_map(lambda g, d: (g * d).sum(), gradients, delta),
        )
        delta2 = jax.tree.reduce(
            operator.add, jax.tree_util.tree_map(lambda d: (d**2).sum(), delta)
        )
        proj = dot / (delta2 + 1.0e-6)

        # If the gradient opposes the displacement, replace it by its
        # projection onto the displacement direction.
        gradients = jax.lax.cond(
            dot >= 0,
            lambda g, d: g,
            lambda g, d: jax.tree_util.tree_map(lambda x: proj * x, d),
            gradients,
            delta,
        )
        updates, main_opt_state = main_opt.update(gradients, main_opt_state, params)
        # Subtract delta so the main step is taken from the reference weights.
        updates = jax.tree_util.tree_map(lambda g, d: g - d, updates, delta)
        return updates, main_opt_state, internal_opt_state

    def update_internal(
        gradients, main_opt_state, internal_opt_state, params, params_ref
    ):
        updates, internal_opt_state = internal_opt.update(
            gradients, internal_opt_state, params
        )
        return updates, main_opt_state, internal_opt_state

    def update_fn(gradients, state, params):
        istep = state.istep % (nsteps_ref + 1)
        # jax.debug.print("{i} {j}",i=istep,j=state.istep)

        # Refresh the reference weights at the start of each cycle.
        params_ref = jax.lax.cond(
            istep == 0, lambda a, b: a, lambda a, b: b, params, state.ref_weights
        )

        updates, main_opt_state, internal_opt_state = jax.lax.cond(
            istep == nsteps_ref,
            update_main,
            update_internal,
            gradients,
            state.main_opt_state,
            state.internal_opt_state,
            params,
            params_ref,
        )

        new_state = PROFITState(
            ref_weights=params_ref,
            istep=state.istep + 1,
            main_opt_state=main_opt_state,
            internal_opt_state=internal_opt_state,
        )
        return updates, new_state

    return base.GradientTransformation(init_fn, update_fn)


class MultiEmaState(NamedTuple):
    """Holds an exponential moving average of past updates."""

    count: int
    ema: Sequence[base.Params]


def multi_ema(
    decays: Sequence[float],
    debias: bool = True,
    power: Union[bool, Sequence[bool]] = False,
) -> base.GradientTransformation:
    """Compute multiple moving averages (EMA or power-law) of past updates.

    Args:
        decays: one decay per tracked average.
        debias: apply optax bias correction to plain-EMA averages.
        power: per-average flag (or one flag for all); when True, the decay is
            replaced by a count-dependent power-law schedule derived from the
            root of ``x^3 + 7x^2 + (16 - t)x + (12 - t)`` with ``t = decay^-2``.

    Returns:
        A transformation whose ``update`` returns the *list* of averaged trees.
    """

    if len(decays) == 0:

        def init_fn(params):
            return base.EmptyState()

        def update_fn(updates, state, params=None):
            # Degenerate case: pass the updates through as a one-element list.
            return [updates], state

        return base.GradientTransformation(init_fn, update_fn)

    if isinstance(power, bool):
        power = [power] * len(decays)
    assert len(power) == len(decays), "power and decays must have the same length"

    gammas = []
    for decay, p in zip(decays, power):
        if not p:
            gammas.append(None)
            continue
        t = decay**-2
        gamma = np.roots([1, 7, 16 - t, 12 - t]).real.max()
        assert gamma > 0, f"Invalid gamma for decay {decay}: {gamma}"
        gammas.append(gamma)

    def init_fn(params):
        return MultiEmaState(
            count=0,
            ema=[otu.tree_zeros_like(params)] * len(decays),
        )

    def update_fn(params, state, other=None):
        count_inc = state.count + 1
        updates = []
        state_ema = []
        for decay, gamma, ema in zip(decays, gammas, state.ema):
            if gamma is not None:
                # Power-law average: decay grows with the step count.
                decay = (1.0 - 1.0 / count_inc) ** (gamma + 1)
            update = new_ema = otu.tree_update_moment(params, ema, decay, order=1)
            if debias and gamma is None:
                update = otu.tree_bias_correction(update, decay, count_inc)
            updates.append(update)
            state_ema.append(new_ema)
        return updates, MultiEmaState(count=count_inc, ema=state_ema)

    return base.GradientTransformation(init_fn, update_fn)


def get_optimizer(
    training_parameters: Dict[str, Any], variables: Dict, initial_lr: float
) -> Tuple[optax.GradientTransformation, int]:
    """Build the training optimizer chain from the configuration.

    Args:
        training_parameters: dictionary containing the training parameters.
        variables: pytree containing the model parameters.
        initial_lr: the initial learning rate.

    Returns:
        A tuple ``(optimizer, ilr)`` where ``optimizer`` is the (partitioned)
        ``optax.GradientTransformation`` and ``ilr`` is the negative index of
        the injected learning-rate stage inside the optimizer chain.
    """

    default_status = str(training_parameters.get("default_status", "trainable")).lower()
    assert default_status in [
        "trainable",
        "frozen",
    ], f"Default status must be 'trainable' or 'frozen', got {default_status}"

    # find frozen and trainable parameters
    frozen = training_parameters.get("frozen", [])
    trainable = training_parameters.get("trainable", [])

    def training_status(full_path, v):
        # Drop the leading collection name ("params"/...) and match the
        # configured regexes; "trainable" patterns override "frozen" ones.
        full_path = "/".join(full_path[1:]).lower()
        status = (default_status, "")
        for path in frozen:
            if re.match(path.lower(), full_path):
                status = ("frozen", path)
        for path in trainable:
            if re.match(path.lower(), full_path):
                status = ("trainable", path)
        return status[0]

    params_partition = traverse_util.path_aware_map(training_status, variables)
    if len(frozen) > 0 or len(trainable) > 0:
        print("params partition:")
        print(json.dumps(params_partition, indent=2, sort_keys=False))

    ## Gradient preprocessing
    grad_processing = []

    # zero nans
    zero_nans = training_parameters.get("zero_nans", False)
    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # gradient clipping
    clip_threshold = training_parameters.get("gradient_clipping", -1.0)
    clip_range = str(training_parameters.get("clipping_range", "global")).lower()
    if clip_threshold > 0.0:
        print("gradient norm clipping threshold:", clip_threshold)
        if clip_range == "local":
            _clip = optax.clip
        elif clip_range == "global":
            _clip = optax.clip_by_global_norm
        elif clip_range == "block":
            _clip = optax.clip_by_block_rms
        else:
            raise ValueError(
                f"Invalid clipping_range: {clip_range}. Should be 'local', 'global' or 'block'."
            )

        grad_processing.append(_clip(clip_threshold))

    use_grokfast = training_parameters.get("use_grokfast", False)
    if use_grokfast:
        print("Using Grokfast")
        alpha_grokfast = training_parameters.get("alpha_grokfast", 0.9)
        l_grokfast = training_parameters.get("l_grokfast", 1.0)
        grad_processing.append(add_grokfast(alpha=alpha_grokfast, l=l_grokfast))

    # OPTIMIZER
    # NOTE(review): `eval` restricted to optax + `profit`, builtins disabled;
    # the optimizer name comes from the (trusted) training configuration.
    optimizer_name = training_parameters.get("optimizer", "adabelief")
    optimizer = eval(
        optimizer_name,
        {"__builtins__": None},
        {
            **optax.__dict__,
            "profit": profit,
        },
    )
    print("Optimizer:", optimizer_name)
    optimizer_configuration = training_parameters.get("optimizer_config", {})
    # Learning rate fixed to 1; the actual step size is injected (and can be
    # mutated at runtime) by the `inject_hyperparams(scale)` stage below.
    optimizer_configuration["learning_rate"] = 1.0
    grad_processing.append(optimizer(**optimizer_configuration))

    # weight decay
    weight_decay = training_parameters.get("weight_decay", 0.0)
    assert weight_decay >= 0.0, "Weight decay must be positive"
    decay_targets = training_parameters.get("decay_targets", [""])

    def decay_status(full_path, v):
        full_path = "/".join(full_path).lower()
        status = False
        for path in decay_targets:
            if re.match(r"^params/" + path.lower(), full_path):
                status = True
        return status

    decay_mask = traverse_util.path_aware_map(decay_status, variables)
    if weight_decay > 0.0:
        print("weight decay:", weight_decay)
        print(json.dumps(decay_mask, indent=2, sort_keys=False))
        # Negative sign: the chain is later scaled by +initial_lr instead of
        # the usual -lr, so decay terms must be pre-negated here.
        grad_processing.append(
            optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask)
        )

    regularize_init_weight = training_parameters.get("regularize_init_weights", 0.0)
    if regularize_init_weight > 0.0:
        print(
            "Regularizing toward initial weights with L2 norm:", regularize_init_weight
        )
        if weight_decay <= 0.0:
            print(json.dumps(decay_mask, indent=2, sort_keys=False))

        grad_processing.append(
            add_weights_difference(
                weight_decay=-regularize_init_weight, mask=decay_mask
            )
        )

    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # learning rate
    grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr))
    # Negative index of the lr stage within the chain (callers use it to
    # update `step_size` each step).
    ilr = -1

    # gradient clipping
    clip_threshold = training_parameters.get("adaptive_gradient_clipping", -1.0)
    if clip_threshold > 0.0:
        print("Adaptive gradient clipping threshold:", clip_threshold)
        grad_processing.append(optax.adaptive_grad_clip(clip_threshold))
        ilr -= 1

    ## define optimizer chain
    optimizer_ = optax.chain(
        *grad_processing,
    )
    partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()}
    return optax.multi_transform(partition_optimizer, params_partition), ilr


def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters):
    """Build the learning-rate schedule from the configuration.

    Each schedule is a function ``schedule(state, rmse=None) -> (lr, state)``:
    called with ``rmse=None`` it advances one training step; called with a
    validation metric it may adapt (only for adaptive schedules).

    Returns:
        ``(schedule, sch_state, schedule_metrics, adaptive_scheduler)``.
    """
    lr = training_parameters.get("lr", 1.0e-3)
    init_lr = training_parameters.get("init_lr", lr / 25)
    final_lr = training_parameters.get("final_lr", lr / 10000)

    #### LEARNING RATE SCHEDULER ####
    schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower()
    # "scheduler" takes precedence over "schedule_type" if both are given.
    schedule_type = training_parameters.get("scheduler", schedule_type).lower()
    schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot")

    adaptive_scheduler = False
    print("Schedule type:", schedule_type)
    if schedule_type == "cosine_onecycle":
        transition_epochs = training_parameters.get("onecycle_epochs", max_epochs)
        peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs)
        linear_warmup = training_parameters.get("linear_warmup", False)

        if linear_warmup:
            schedule_ = optax.warmup_cosine_decay_schedule(
                init_value=init_lr,
                peak_value=lr,
                warmup_steps=peak_epoch * nbatch_per_epoch,
                decay_steps=transition_epochs * nbatch_per_epoch,
                end_value=final_lr,
            )
        else:
            schedule_ = optax.cosine_onecycle_schedule(
                peak_value=lr,
                div_factor=lr / init_lr,
                final_div_factor=init_lr / final_lr,
                transition_steps=transition_epochs * nbatch_per_epoch,
                pct_start=peak_epoch / transition_epochs,
            )
        sch_state = {"count": 0, "best": np.inf, "lr": init_lr}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "piecewise_interpolate":
        schedule_params = training_parameters.get("scheduler_parameters", {})
        schedule_ = optax.piecewise_interpolate_schedule(
            **{"init_value": lr, "interpolate_type": "linear", **schedule_params}
        )
        sch_state = {"count": 0, "best": np.inf, "lr": schedule_(0)}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "constant":
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "cosine":
        assert (
            "peak_epoch" in training_parameters
        ), "Sine schedule requires 'peak_epoch' parameter"
        period = training_parameters["peak_epoch"] * nbatch_per_epoch
        peak_lr = lr
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            istep = state["count"]
            g = 0.5 * (1 + jnp.cos(jnp.pi * istep / period))
            lr = peak_lr + (init_lr - peak_lr) * g
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "reduce_on_plateau":
        patience = training_parameters.get("patience", 10)
        factor = training_parameters.get("lr_factor", 0.5)
        patience_thr = training_parameters.get("patience_thr", 0.0)
        sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience}
        adaptive_scheduler = True

        def schedule(state, rmse=None):
            new_state = {**state}
            if rmse is None:
                new_state["count"] += 1
                return state["lr"], new_state
            if rmse <= state["best"] * (1.0 + patience_thr):
                if rmse < state["best"]:
                    new_state["best"] = rmse
                new_state["patience"] = 0
            else:
                new_state["patience"] += 1
                if new_state["patience"] >= patience:
                    new_state["lr"] = state["lr"] * factor
                    new_state["patience"] = 0
                    print("Reducing learning rate to", new_state["lr"])
            return new_state["lr"], new_state

    else:
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

    stochastic_scheduler = training_parameters.get("stochastic_scheduler", False)
    if stochastic_scheduler:
        # Wrap the base schedule: it now provides lr_max, and the effective lr
        # is drawn uniformly in [lr_min, lr_max] each step.
        schedule_ = schedule
        # BUGFIX: the original referenced an undefined `rng_key` here
        # (NameError); seed the scheduler key from the configuration instead.
        scheduler_key = jax.random.PRNGKey(
            training_parameters.get("scheduler_seed", 0)
        )
        sch_state["rng_key"] = scheduler_key
        sch_state["lr_max"] = lr
        sch_state["lr_min"] = final_lr

        def schedule(state, rmse=None):
            new_state = {**state, "lr": state["lr_max"]}
            if rmse is None:
                lr_max, new_state = schedule_(new_state, rmse=rmse)
                lr_min = new_state["lr_min"]
                new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"])
                lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey)
                new_state["lr"] = lr
                new_state["lr_max"] = lr_max

            return new_state["lr"], new_state

    return schedule, sch_state, schedule_metrics, adaptive_scheduler
class
AddWeightDiffState(typing.NamedTuple):
AddWeightDiffState(ref_weights,)
def
add_weights_difference( weight_decay: Union[float, jax.Array] = 0.0, mask: Union[Any, Callable[[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]]], Any], NoneType] = None) -> optax._src.base.GradientTransformation:
31def add_weights_difference( 32 weight_decay: Union[float, jax.Array] = 0.0, 33 mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, 34) -> base.GradientTransformation: 35 """weight decay toward initial weights.""" 36 37 def init_fn(params): 38 return AddWeightDiffState(ref_weights=params) 39 40 def update_fn(updates, state, params): 41 if params is None: 42 raise ValueError(base.NO_PARAMS_MSG) 43 updates = jax.tree_util.tree_map( 44 lambda g, p, pref: g + weight_decay * (p - pref), 45 updates, 46 params, 47 state.ref_weights, 48 ) 49 return updates, state 50 51 # If mask is not `None`, apply mask to the gradient transformation. 52 # E.g. it is common to skip weight decay on bias units and batch stats. 53 if mask is not None: 54 return wrappers.masked(base.GradientTransformation(init_fn, update_fn), mask) 55 return base.GradientTransformation(init_fn, update_fn)
weight decay toward initial weights.
def
add_grokfast( alpha: float = 0.9, l: float = 1.0) -> optax._src.base.GradientTransformation:
58def add_grokfast( 59 alpha: float = 0.9, 60 l: float = 1.0, 61) -> base.GradientTransformation: 62 """Grokfast: amplify slow gradients by exponential moving average.""" 63 64 ema: base.GradientTransformation = optax.ema(decay=alpha, debias=False) 65 66 def init_fn(params): 67 return ema.init(params) 68 69 def update_fn(updates, state, params=None): 70 dupdates, state = ema.update(updates, state, params) 71 # updates = updates + l*dupdates 72 updates = jax.tree_util.tree_map(lambda g, d: g + l * d, updates, dupdates) 73 return updates, state 74 75 return base.GradientTransformation(init_fn, update_fn)
Grokfast: amplify slow gradients by exponential moving average.
class
PROFITState(typing.NamedTuple):
78class PROFITState(NamedTuple): 79 ref_weights: Any 80 istep: int 81 main_opt_state: Any 82 internal_opt_state: Any
PROFITState(ref_weights, istep, main_opt_state, internal_opt_state)
def
profit( learning_rate: Union[float, jax.Array, Callable[[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]]], nsteps_ref: int = 1, main_opt: str = 'adam', main_opt_params: Dict[str, Any] = {}, internal_opt: str = 'sgd', internal_opt_params: Dict[str, Any] = {}, **kwargs) -> optax._src.base.GradientTransformation:
85def profit( 86 learning_rate: base.ScalarOrSchedule, 87 nsteps_ref: int = 1, 88 main_opt: str = "adam", 89 main_opt_params: Dict[str, Any] = {}, 90 internal_opt: str = "sgd", 91 internal_opt_params: Dict[str, Any] = {}, 92 **kwargs, 93) -> base.GradientTransformation: 94 """PROFIT optimizer for fine-tuning https://arxiv.org/pdf/2412.01930""" 95 96 main_opt_params = {"learning_rate": learning_rate, **main_opt_params} 97 main_opt = eval( 98 main_opt, 99 {"__builtins__": None}, 100 {**optax.__dict__}, 101 )(**main_opt_params) 102 103 internal_opt_params = {"learning_rate": 0.1, **internal_opt_params} 104 internal_opt = eval( 105 internal_opt, 106 {"__builtins__": None}, 107 {**optax.__dict__}, 108 )(**internal_opt_params) 109 110 def init_fn(params): 111 return PROFITState( 112 ref_weights=params, 113 istep=0, 114 main_opt_state=main_opt.init(params), 115 internal_opt_state=internal_opt.init(params), 116 ) 117 118 def update_main(gradients, main_opt_state, internal_opt_state, params, params_ref): 119 delta = jax.tree_util.tree_map(lambda p, pref: p - pref, params, params_ref) 120 dot = jax.tree.reduce( 121 operator.add, 122 jax.tree_util.tree_map(lambda g, d: (g * d).sum(), gradients, delta), 123 ) 124 delta2 = jax.tree.reduce( 125 operator.add, jax.tree_util.tree_map(lambda d: (d**2).sum(), delta) 126 ) 127 proj = dot / (delta2 + 1.0e-6) 128 129 gradients = jax.lax.cond( 130 dot >= 0, 131 lambda g, d: g, 132 lambda g, d: jax.tree_util.tree_map(lambda x: proj * x, d), 133 gradients, 134 delta, 135 ) 136 updates, main_opt_state = main_opt.update(gradients, main_opt_state, params) 137 updates = jax.tree_util.tree_map(lambda g, d: g - d, updates, delta) 138 return updates, main_opt_state, internal_opt_state 139 140 def update_internal( 141 gradients, main_opt_state, internal_opt_state, params, params_ref 142 ): 143 updates, internal_opt_state = internal_opt.update( 144 gradients, internal_opt_state, params 145 ) 146 return updates, main_opt_state, internal_opt_state 
147 148 def update_fn(gradients, state, params): 149 istep = state.istep % (nsteps_ref + 1) 150 # jax.debug.print("{i} {j}",i=istep,j=state.istep) 151 152 params_ref = jax.lax.cond( 153 istep == 0, lambda a, b: a, lambda a, b: b, params, state.ref_weights 154 ) 155 156 updates, main_opt_state, internal_opt_state = jax.lax.cond( 157 istep == nsteps_ref, 158 update_main, 159 update_internal, 160 gradients, 161 state.main_opt_state, 162 state.internal_opt_state, 163 params, 164 params_ref, 165 ) 166 167 new_state = PROFITState( 168 ref_weights=params_ref, 169 istep=state.istep + 1, 170 main_opt_state=main_opt_state, 171 internal_opt_state=internal_opt_state, 172 ) 173 return updates, new_state 174 175 return base.GradientTransformation(init_fn, update_fn)
PROFIT optimizer for fine-tuning https://arxiv.org/pdf/2412.01930
class
MultiEmaState(typing.NamedTuple):
178class MultiEmaState(NamedTuple): 179 """Holds an exponential moving average of past updates.""" 180 181 count: int 182 ema: Sequence[base.Params]
Holds an exponential moving average of past updates.
def
multi_ema( decays: Sequence[float], debias: bool = True, power: Union[bool, Sequence[bool]] = False) -> optax._src.base.GradientTransformation:
185def multi_ema( 186 decays: Sequence[float], 187 debias: bool = True, 188 power: Union[bool, Sequence[bool]] = False, 189) -> base.GradientTransformation: 190 """Compute mutliple power moving averages of past updates.""" 191 192 if len(decays) == 0: 193 194 def init_fn(params): 195 return base.EmptyState() 196 197 def update_fn(updates, state, params=None): 198 return [updates], state 199 200 return base.GradientTransformation(init_fn, update_fn) 201 202 if isinstance(power, bool): 203 power = [power] * len(decays) 204 assert len(power) == len(decays), "power and decays must have the same length" 205 206 gammas = [] 207 for decay, p in zip(decays, power): 208 if not p: 209 gammas.append(None) 210 continue 211 t = decay**-2 212 gamma = np.roots([1, 7, 16 - t, 12 - t]).real.max() 213 assert gamma > 0, f"Invalid gamma for decay {decay}: {gamma}" 214 gammas.append(gamma) 215 216 def init_fn(params): 217 return MultiEmaState( 218 count=0, 219 ema=[otu.tree_zeros_like(params)] * len(decays), 220 ) 221 222 def update_fn(params, state, other=None): 223 count_inc = state.count + 1 224 updates = [] 225 state_ema = [] 226 for decay, gamma, ema in zip(decays, gammas, state.ema): 227 if gamma is not None: 228 decay = (1.0 - 1.0 / count_inc) ** (gamma + 1) 229 update = new_ema = otu.tree_update_moment(params, ema, decay, order=1) 230 if debias and gamma is None: 231 update = otu.tree_bias_correction(update, decay, count_inc) 232 updates.append(update) 233 state_ema.append(new_ema) 234 return updates, MultiEmaState(count=count_inc, ema=state_ema) 235 236 return base.GradientTransformation(init_fn, update_fn)
Compute multiple power moving averages of past updates.
def
get_optimizer( training_parameters: Dict[str, <built-in function any>], variables: Dict, initial_lr: float) -> optax._src.base.GradientTransformation:
239def get_optimizer( 240 training_parameters: Dict[str, any], variables: Dict, initial_lr: float 241) -> optax.GradientTransformation: 242 """ 243 Returns an optax.GradientTransformation object that can be used to optimize the model parameters. 244 245 Args: 246 - training_parameters: A dictionary containing the training parameters. 247 - variables: A pytree containing the model parameters. 248 - initial_lr: The initial learning rate. 249 250 Returns: 251 - An optax.GradientTransformation object that can be used to optimize the model parameters. 252 """ 253 254 default_status = str(training_parameters.get("default_status", "trainable")).lower() 255 assert default_status in [ 256 "trainable", 257 "frozen", 258 ], f"Default status must be 'trainable' or 'frozen', got {default_status}" 259 260 # find frozen and trainable parameters 261 frozen = training_parameters.get("frozen", []) 262 trainable = training_parameters.get("trainable", []) 263 264 def training_status(full_path, v): 265 full_path = "/".join(full_path[1:]).lower() 266 # full_path = "/".join(full_path[1:]).lower() 267 status = (default_status, "") 268 for path in frozen: 269 # if full_path.startswith(path.lower()) and len(path) > len(status[1]): 270 if re.match(path.lower(), full_path): 271 status = ("frozen", path) 272 for path in trainable: 273 # if full_path.startswith(path.lower()) and len(path) > len(status[1]): 274 if re.match(path.lower(), full_path): 275 status = ("trainable", path) 276 return status[0] 277 278 params_partition = traverse_util.path_aware_map(training_status, variables) 279 if len(frozen) > 0 or len(trainable) > 0: 280 print("params partition:") 281 print(json.dumps(params_partition, indent=2, sort_keys=False)) 282 283 ## Gradient preprocessing 284 grad_processing = [] 285 286 # zero nans 287 zero_nans = training_parameters.get("zero_nans", False) 288 if zero_nans: 289 grad_processing.append(optax.zero_nans()) 290 291 # gradient clipping 292 clip_threshold = 
training_parameters.get("gradient_clipping", -1.0) 293 clip_range = str(training_parameters.get("clipping_range", "global")).lower() 294 if clip_threshold > 0.0: 295 print("gradient norm clipping threshold:", clip_threshold) 296 if clip_range == "local": 297 _clip = optax.clip 298 elif clip_range == "global": 299 _clip = optax.clip_by_global_norm 300 elif clip_range == "block": 301 _clip = optax.clip_by_block_rms 302 else: 303 raise ValueError(f"Invalid gradnorm_clipping_range: {clip_range}. Should be 'local', 'global' or 'block'.") 304 305 grad_processing.append(_clip(clip_threshold)) 306 307 use_grokfast = training_parameters.get("use_grokfast", False) 308 if use_grokfast: 309 print("Using Grokfast") 310 alpha_grokfast = training_parameters.get("alpha_grokfast", 0.9) 311 l_grokfast = training_parameters.get("l_grokfast", 1.0) 312 grad_processing.append(add_grokfast(alpha=alpha_grokfast, l=l_grokfast)) 313 314 # OPTIMIZER 315 optimizer_name = training_parameters.get("optimizer", "adabelief") 316 optimizer = eval( 317 optimizer_name, 318 {"__builtins__": None}, 319 { 320 **optax.__dict__, 321 "profit": profit, 322 }, 323 ) 324 print("Optimizer:", optimizer_name) 325 optimizer_configuration = training_parameters.get("optimizer_config", {}) 326 optimizer_configuration["learning_rate"] = 1.0 327 grad_processing.append(optimizer(**optimizer_configuration)) 328 329 # weight decay 330 weight_decay = training_parameters.get("weight_decay", 0.0) 331 assert weight_decay >= 0.0, "Weight decay must be positive" 332 decay_targets = training_parameters.get("decay_targets", [""]) 333 334 def decay_status(full_path, v): 335 full_path = "/".join(full_path).lower() 336 status = False 337 # print(full_path,re.match(r'^params\/', full_path)) 338 for path in decay_targets: 339 if re.match(r"^params/" + path.lower(), full_path): 340 status = True 341 # if full_path.startswith("params/" + path.lower()): 342 # status = True 343 return status 344 345 decay_mask = 
traverse_util.path_aware_map(decay_status, variables) 346 if weight_decay > 0.0: 347 print("weight decay:", weight_decay) 348 print(json.dumps(decay_mask, indent=2, sort_keys=False)) 349 grad_processing.append( 350 optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask) 351 ) 352 353 regularize_init_weight = training_parameters.get("regularize_init_weights", 0.0) 354 if regularize_init_weight > 0.0: 355 print( 356 "Regularizing toward initial weights with L2 norm:", regularize_init_weight 357 ) 358 if weight_decay <= 0.0: 359 print(json.dumps(decay_mask, indent=2, sort_keys=False)) 360 361 grad_processing.append( 362 add_weights_difference( 363 weight_decay=-regularize_init_weight, mask=decay_mask 364 ) 365 ) 366 367 if zero_nans: 368 grad_processing.append(optax.zero_nans()) 369 370 # learning rate 371 grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr)) 372 ilr = -1 373 374 # gradient clipping 375 clip_threshold = training_parameters.get("adaptive_gradient_clipping", -1.0) 376 if clip_threshold > 0.0: 377 print("Adaptive gradient clipping threshold:", clip_threshold) 378 grad_processing.append(optax.adaptive_grad_clip(clip_threshold)) 379 ilr -= 1 380 381 ## define optimizer chain 382 optimizer_ = optax.chain( 383 *grad_processing, 384 ) 385 partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()} 386 return optax.multi_transform(partition_optimizer, params_partition), ilr
Returns an optax.GradientTransformation object that can be used to optimize the model parameters.
Args:
- training_parameters: A dictionary containing the training parameters.
- variables: A pytree containing the model parameters.
- initial_lr: The initial learning rate.
Returns:
- An optax.GradientTransformation object that can be used to optimize the model parameters.
def
get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters):
389def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters): 390 lr = training_parameters.get("lr", 1.0e-3) 391 init_lr = training_parameters.get("init_lr", lr / 25) 392 final_lr = training_parameters.get("final_lr", lr / 10000) 393 394 #### LEARNING RATE SCHEDULER #### 395 schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower() 396 schedule_type = training_parameters.get("scheduler", schedule_type).lower() 397 schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot") 398 399 adaptive_scheduler = False 400 print("Schedule type:", schedule_type) 401 if schedule_type == "cosine_onecycle": 402 transition_epochs = training_parameters.get("onecycle_epochs", max_epochs) 403 peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs) 404 linear_warmup = training_parameters.get("linear_warmup", False) 405 406 if linear_warmup: 407 schedule_ = optax.warmup_cosine_decay_schedule( 408 init_value=init_lr, 409 peak_value=lr, 410 warmup_steps=peak_epoch * nbatch_per_epoch, 411 decay_steps=transition_epochs * nbatch_per_epoch, 412 end_value=final_lr, 413 ) 414 else: 415 schedule_ = optax.cosine_onecycle_schedule( 416 peak_value=lr, 417 div_factor=lr / init_lr, 418 final_div_factor=init_lr / final_lr, 419 transition_steps=transition_epochs * nbatch_per_epoch, 420 pct_start=peak_epoch / transition_epochs, 421 ) 422 sch_state = {"count": 0, "best": np.inf, "lr": init_lr} 423 424 def schedule(state, rmse=None): 425 new_state = {**state} 426 lr = schedule_(state["count"]) 427 if rmse is None: 428 new_state["count"] += 1 429 new_state["lr"] = lr 430 return lr, new_state 431 432 elif schedule_type == "piecewise_interpolate": 433 schedule_params = training_parameters.get("scheduler_parameters", {}) 434 schedule_ = optax.piecewise_interpolate_schedule( 435 **{"init_value": lr, "interpolate_type": "linear", **schedule_params} 436 ) 437 sch_state = {"count": 0, "best": np.inf, "lr": schedule_(0)} 438 439 def 
schedule(state, rmse=None): 440 new_state = {**state} 441 lr = schedule_(state["count"]) 442 if rmse is None: 443 new_state["count"] += 1 444 new_state["lr"] = lr 445 return lr, new_state 446 447 elif schedule_type == "constant": 448 sch_state = {"count": 0} 449 450 def schedule(state, rmse=None): 451 new_state = {**state} 452 new_state["lr"] = lr 453 if rmse is None: 454 new_state["count"] += 1 455 return lr, new_state 456 457 elif schedule_type == "cosine": 458 assert ( 459 "peak_epoch" in training_parameters 460 ), "Sine schedule requires 'peak_epoch' parameter" 461 period = training_parameters["peak_epoch"] * nbatch_per_epoch 462 peak_lr = lr 463 sch_state = {"count": 0} 464 465 def schedule(state, rmse=None): 466 new_state = {**state} 467 istep = state["count"] 468 g = 0.5 * (1 + jnp.cos(jnp.pi * istep / period)) 469 lr = peak_lr + (init_lr - peak_lr) * g 470 new_state["lr"] = lr 471 if rmse is None: 472 new_state["count"] += 1 473 return lr, new_state 474 475 elif schedule_type == "reduce_on_plateau": 476 patience = training_parameters.get("patience", 10) 477 factor = training_parameters.get("lr_factor", 0.5) 478 patience_thr = training_parameters.get("patience_thr", 0.0) 479 sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience} 480 adaptive_scheduler = True 481 482 def schedule(state, rmse=None): 483 new_state = {**state} 484 if rmse is None: 485 new_state["count"] += 1 486 return state["lr"], new_state 487 if rmse <= state["best"] * (1.0 + patience_thr): 488 if rmse < state["best"]: 489 new_state["best"] = rmse 490 new_state["patience"] = 0 491 else: 492 new_state["patience"] += 1 493 if new_state["patience"] >= patience: 494 new_state["lr"] = state["lr"] * factor 495 new_state["patience"] = 0 496 print("Reducing learning rate to", new_state["lr"]) 497 return new_state["lr"], new_state 498 499 else: 500 raise ValueError(f"Unknown schedule_type: {schedule_type}") 501 502 stochastic_scheduler = 
training_parameters.get("stochastic_scheduler", False) 503 if stochastic_scheduler: 504 schedule_ = schedule 505 rng_key, scheduler_key = jax.random.split(rng_key) 506 sch_state["rng_key"] = scheduler_key 507 sch_state["lr_max"] = lr 508 sch_state["lr_min"] = final_lr 509 510 def schedule(state, rmse=None): 511 new_state = {**state, "lr": state["lr_max"]} 512 if rmse is None: 513 lr_max, new_state = schedule_(new_state, rmse=rmse) 514 lr_min = new_state["lr_min"] 515 new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"]) 516 lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey) 517 new_state["lr"] = lr 518 new_state["lr_max"] = lr_max 519 520 return new_state["lr"], new_state 521 522 return schedule, sch_state, schedule_metrics, adaptive_scheduler