"""
featurizers.py
==============
Utility classes for defining *invertible* feature spaces on top of a model’s
hidden-state tensors, together with intervention helpers that operate inside
those spaces.
Key ideas
---------
* **Featurizer** – a lightweight wrapper holding:
• a forward `featurizer` module that maps a tensor **x → (f, error)**
where *error* is the reconstruction residual (useful for lossy
featurizers such as sparse auto-encoders);
• an `inverse_featurizer` that re-assembles the original space
**(f, error) → x̂**.
* **Interventions** – three higher-order factory functions build PyVENE
interventions that work in the featurized space:
- *interchange*
- *collect*
- *mask* (differential binary masking)
All public classes / functions below carry PEP-257-style doc-strings.
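Example
-------
A minimal round-trip through the (lossless) identity featurizer::

    feat = Featurizer()                      # identity featurizer
    x = torch.randn(2, 16)
    f, error = feat.featurize(x)             # f is x, error is None
    x_hat = feat.inverse_featurize(f, error)
    assert torch.equal(x_hat, x)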
"""
from typing import Optional, Tuple
import torch
import pyvene as pv
# --------------------------------------------------------------------------- #
# Basic identity featurizers #
# --------------------------------------------------------------------------- #
class IdentityFeaturizerModule(torch.nn.Module):
"""A no-op featurizer: *x → (x, None)*."""
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
return x, None
class IdentityInverseFeaturizerModule(torch.nn.Module):
"""Inverse of :class:`IdentityFeaturizerModule`."""
def forward(self, x: torch.Tensor, error: None) -> torch.Tensor: # noqa: D401
return x
# --------------------------------------------------------------------------- #
# High-level Featurizer wrapper #
# --------------------------------------------------------------------------- #
class Featurizer:
"""Container object holding paired featurizer and inverse modules.
Parameters
----------
featurizer :
A `torch.nn.Module` mapping **x → (features, error)**.
inverse_featurizer :
A `torch.nn.Module` mapping **(features, error) → x̂**.
n_features :
Dimensionality of the feature space. **Required** when you intend to
build a *mask* intervention; optional otherwise.
id :
Human-readable identifier used by `__str__` methods of the generated
interventions.
"""
# --------------------------------------------------------------------- #
# Construction / public accessors #
# --------------------------------------------------------------------- #
    def __init__(
        self,
        featurizer: Optional[torch.nn.Module] = None,
        inverse_featurizer: Optional[torch.nn.Module] = None,
        *,
        n_features: Optional[int] = None,
        id: str = "null",
    ):
        # Default to fresh identity modules; avoids sharing module instances
        # across Featurizer objects (the mutable-default-argument pitfall).
        self.featurizer = (
            featurizer if featurizer is not None else IdentityFeaturizerModule()
        )
        self.inverse_featurizer = (
            inverse_featurizer
            if inverse_featurizer is not None
            else IdentityInverseFeaturizerModule()
        )
        self.n_features = n_features
        self.id = id
# -------------------- Intervention builders -------------------------- #
def get_interchange_intervention(self):
if not hasattr(self, "_interchange_intervention"):
self._interchange_intervention = build_feature_interchange_intervention(
self.featurizer, self.inverse_featurizer, self.id
)
return self._interchange_intervention
def get_collect_intervention(self):
if not hasattr(self, "_collect_intervention"):
self._collect_intervention = build_feature_collect_intervention(
self.featurizer, self.id
)
return self._collect_intervention
def get_mask_intervention(self):
if self.n_features is None:
raise ValueError(
"`n_features` must be provided on the Featurizer "
"to construct a mask intervention."
)
if not hasattr(self, "_mask_intervention"):
self._mask_intervention = build_feature_mask_intervention(
self.featurizer,
self.inverse_featurizer,
self.n_features,
self.id,
)
return self._mask_intervention
# ------------------------- Convenience I/O --------------------------- #
def featurize(self, x: torch.Tensor):
return self.featurizer(x)
def inverse_featurize(self, x: torch.Tensor, error):
return self.inverse_featurizer(x, error)
# --------------------------------------------------------------------- #
# (De)serialisation helpers #
# --------------------------------------------------------------------- #
    def save_modules(self, path: str) -> Tuple[Optional[str], Optional[str]]:
        """Serialise featurizer & inverse to ``<path>_featurizer`` and
        ``<path>_inverse_featurizer``.
        Notes
        -----
        * **SAE featurizers** are *not* serialised here – they are reloaded
          from ``sae_lens``, so ``(None, None)`` is returned instead.
        * Existing files will be *silently overwritten*.
        """
        featurizer_class = self.featurizer.__class__.__name__
        if featurizer_class == "SAEFeaturizerModule":
            # SAE featurizers are reloaded from sae_lens rather than serialised.
            return None, None
inverse_featurizer_class = self.inverse_featurizer.__class__.__name__
# Extra config needed for Subspace featurizers
additional_config = {}
if featurizer_class == "SubspaceFeaturizerModule":
additional_config["rotation_matrix"] = (
self.featurizer.rotate.weight.detach().clone()
)
additional_config["requires_grad"] = (
self.featurizer.rotate.weight.requires_grad
)
        model_info = {
            "featurizer_class": featurizer_class,
            "inverse_featurizer_class": inverse_featurizer_class,
            "n_features": self.n_features,
            # Persist the id so load_modules() can round-trip it.
            "featurizer_id": self.id,
            "additional_config": additional_config,
        }
torch.save(
{"model_info": model_info, "state_dict": self.featurizer.state_dict()},
f"{path}_featurizer",
)
torch.save(
{
"model_info": model_info,
"state_dict": self.inverse_featurizer.state_dict(),
},
f"{path}_inverse_featurizer",
)
return f"{path}_featurizer", f"{path}_inverse_featurizer"
@classmethod
def load_modules(cls, path: str) -> "Featurizer":
"""Inverse of :meth:`save_modules`.
Returns
-------
Featurizer
A *new* instance with reconstructed modules and metadata.
"""
featurizer_data = torch.load(f"{path}_featurizer")
inverse_data = torch.load(f"{path}_inverse_featurizer")
model_info = featurizer_data["model_info"]
featurizer_class = model_info["featurizer_class"]
if featurizer_class == "SubspaceFeaturizerModule":
rot = model_info["additional_config"]["rotation_matrix"]
requires_grad = model_info["additional_config"]["requires_grad"]
# Re-build a parametrised orthogonal layer with identical shape.
in_dim, out_dim = rot.shape
rotate_layer = pv.models.layers.LowRankRotateLayer(
in_dim, out_dim, init_orth=False
)
rotate_layer.weight.data.copy_(rot)
rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
rotate_layer.requires_grad_(requires_grad)
featurizer = SubspaceFeaturizerModule(rotate_layer)
inverse = SubspaceInverseFeaturizerModule(rotate_layer)
# Sanity-check weight shape
assert (
featurizer.rotate.weight.shape == rot.shape
), "Rotation-matrix shape mismatch after deserialisation."
elif featurizer_class == "IdentityFeaturizerModule":
featurizer = IdentityFeaturizerModule()
inverse = IdentityInverseFeaturizerModule()
else:
raise ValueError(f"Unknown featurizer class '{featurizer_class}'.")
featurizer.load_state_dict(featurizer_data["state_dict"])
inverse.load_state_dict(inverse_data["state_dict"])
return cls(
featurizer,
inverse,
n_features=model_info["n_features"],
id=model_info.get("featurizer_id", "loaded"),
)
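# Round-trip sketch (hedged; paths are illustrative): `save_modules` returns the
# two file paths it wrote, and `load_modules` rebuilds an equivalent Featurizer
# from the same stem:
#
#     feat = SubspaceFeaturizer(shape=(768, 32))
#     feat.save_modules("/tmp/subspace")
#     restored = Featurizer.load_modules("/tmp/subspace")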
# --------------------------------------------------------------------------- #
# Intervention factory helpers #
# --------------------------------------------------------------------------- #
def build_feature_interchange_intervention(
featurizer: torch.nn.Module,
inverse_featurizer: torch.nn.Module,
featurizer_id: str,
):
"""Return a class implementing PyVENE’s TrainableIntervention."""
class FeatureInterchangeIntervention(
pv.TrainableIntervention, pv.DistributedRepresentationIntervention
):
"""Swap features between *base* and *source* in the featurized space."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._featurizer = featurizer
self._inverse = inverse_featurizer
def forward(self, base, source, subspaces=None):
f_base, base_err = self._featurizer(base)
f_src, _ = self._featurizer(source)
            if _subspace_is_all_none(subspaces):  # helper also handles subspaces=None
                f_out = f_src
else:
f_out = pv.models.intervention_utils._do_intervention_by_swap(
f_base,
f_src,
"interchange",
self.interchange_dim,
subspaces,
subspace_partition=self.subspace_partition,
use_fast=self.use_fast,
)
return self._inverse(f_out, base_err).to(base.dtype)
def __str__(self): # noqa: D401
return f"FeatureInterchangeIntervention(id={featurizer_id})"
return FeatureInterchangeIntervention
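# Usage sketch (hedged – assumes pyvene's standard config API; the layer and
# component values are illustrative, and `model` is a loaded HF model):
#
#     feat = SubspaceFeaturizer(shape=(768, 32))
#     config = pv.IntervenableConfig(
#         representations=[{"layer": 3, "component": "block_output"}],
#         intervention_types=feat.get_interchange_intervention(),
#     )
#     intervenable = pv.IntervenableModel(config, model)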
def build_feature_collect_intervention(
featurizer: torch.nn.Module, featurizer_id: str
):
"""Return a `CollectIntervention` operating in feature space."""
class FeatureCollectIntervention(pv.CollectIntervention):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._featurizer = featurizer
def forward(self, base, source=None, subspaces=None):
f_base, _ = self._featurizer(base)
return pv.models.intervention_utils._do_intervention_by_swap(
f_base,
source,
"collect",
self.interchange_dim,
subspaces,
subspace_partition=self.subspace_partition,
use_fast=self.use_fast,
)
def __str__(self): # noqa: D401
return f"FeatureCollectIntervention(id={featurizer_id})"
return FeatureCollectIntervention
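# Note: the collect variant is read-only – pyvene gathers the featurized
# activations instead of writing anything back. It plugs into a config the
# same way as the interchange class above, e.g.:
#
#     CollectCls = feat.get_collect_intervention()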
def build_feature_mask_intervention(
featurizer: torch.nn.Module,
inverse_featurizer: torch.nn.Module,
n_features: int,
featurizer_id: str,
):
"""Return a trainable mask intervention."""
class FeatureMaskIntervention(pv.TrainableIntervention):
"""Differential-binary masking in the featurized space."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._featurizer = featurizer
self._inverse = inverse_featurizer
# Learnable parameters
self.mask = torch.nn.Parameter(torch.zeros(n_features), requires_grad=True)
self.temperature: Optional[torch.Tensor] = None # must be set by user
# -------------------- API helpers -------------------- #
def get_temperature(self) -> torch.Tensor:
if self.temperature is None:
raise ValueError("Temperature has not been set.")
return self.temperature
def set_temperature(self, temp: float | torch.Tensor):
self.temperature = (
torch.as_tensor(temp, dtype=self.mask.dtype).to(self.mask.device)
)
def _nonlinear_transform(self, f: torch.Tensor) -> torch.Tensor:
# You can swap this for a real MLP if desired
return torch.tanh(f)
# ------------------------- forward ------------------- #
def forward(self, base, source, subspaces=None):
if self.temperature is None:
raise ValueError("Cannot run forward without a temperature.")
f_base, base_err = self._featurizer(base)
f_src, _ = self._featurizer(source)
# Align devices / dtypes
mask = self.mask.to(f_base.device)
temp = self.temperature.to(f_base.device)
f_base = f_base.to(mask.dtype)
f_src = f_src.to(mask.dtype)
            if self.training:
                # Soft gate: temperature-scaled sigmoid relaxation of a binary mask.
                gate = torch.sigmoid(mask / temp)
            else:
                # Hard gate at eval time (sigmoid(m) > 0.5  <=>  m > 0).
                gate = (torch.sigmoid(mask) > 0.5).float()
            f_out = (1.0 - gate) * f_base + gate * f_src
# === Apply nonlinearity during training only ===
# if self.training:
# f_out = self._nonlinear_transform(f_out)
return self._inverse(f_out.to(base.dtype), base_err).to(base.dtype)
# ---------------- Sparsity regulariser --------------- #
def get_sparsity_loss(self) -> torch.Tensor:
if self.temperature is None:
raise ValueError("Temperature has not been set.")
gate = torch.sigmoid(self.mask / self.temperature)
return torch.norm(gate, p=1)
def __str__(self): # noqa: D401
return f"FeatureMaskIntervention(id={featurizer_id})"
return FeatureMaskIntervention
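# Training sketch (hedged – pyvene normally instantiates the returned class
# itself via IntervenableModel; `task_loss`, `embed_dim=768`, and the 1e-3
# sparsity coefficient are illustrative placeholders):
#
#     MaskCls = feat.get_mask_intervention()
#     intervention = MaskCls(embed_dim=768)
#     intervention.set_temperature(1.0)        # typically annealed over training
#     patched = intervention(base_acts, source_acts)
#     loss = task_loss + 1e-3 * intervention.get_sparsity_loss()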
# --------------------------------------------------------------------------- #
# Concrete featurizer implementations #
# --------------------------------------------------------------------------- #
class SubspaceFeaturizerModule(torch.nn.Module):
"""Linear projector onto an orthogonal *rotation* sub-space."""
def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
super().__init__()
self.rotate = rotate_layer
    def forward(self, x: torch.Tensor):
        r = self.rotate.weight.T  # (n_features, hidden_dim)
        f = x.to(r.dtype) @ r.T  # project onto the orthogonal subspace
        error = x - (f @ r).to(x.dtype)  # residual outside the subspace
        return f, error
class SubspaceInverseFeaturizerModule(torch.nn.Module):
"""Inverse of :class:`SubspaceFeaturizerModule`."""
def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
super().__init__()
self.rotate = rotate_layer
def forward(self, f, error):
r = self.rotate.weight.T
return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype)
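# Note: because `error` stores the residual x - (x R) Rᵀ (with R =
# rotate.weight, an orthonormal hidden_dim × n_features matrix), the
# featurizer / inverse pair reconstructs the input exactly:
# (x R) Rᵀ + error == x.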
class SubspaceFeaturizer(Featurizer):
"""Orthogonal linear sub-space featurizer."""
def __init__(
self,
*,
shape: Tuple[int, int] | None = None,
rotation_subspace: torch.Tensor | None = None,
trainable: bool = True,
id: str = "subspace",
):
assert (
shape is not None or rotation_subspace is not None
), "Provide either `shape` or `rotation_subspace`."
if shape is not None:
rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
else:
shape = rotation_subspace.shape
rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=False)
rotate.weight.data.copy_(rotation_subspace)
rotate = torch.nn.utils.parametrizations.orthogonal(rotate)
rotate.requires_grad_(trainable)
super().__init__(
SubspaceFeaturizerModule(rotate),
SubspaceInverseFeaturizerModule(rotate),
n_features=rotate.weight.shape[1],
id=id,
)
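# Example: a trainable 32-dimensional subspace over a 768-dimensional stream
# (dimensions are illustrative):
#
#     feat = SubspaceFeaturizer(shape=(768, 32), trainable=True, id="demo")
#     f, err = feat.featurize(torch.randn(4, 768))   # f has shape (4, 32)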
class SAEFeaturizerModule(torch.nn.Module):
"""Wrapper around a *Sparse Autoencoder*’s encode() / decode() pair."""
def __init__(self, sae):
super().__init__()
self.sae = sae
def forward(self, x):
features = self.sae.encode(x.to(self.sae.dtype))
error = x - self.sae.decode(features).to(x.dtype)
return features.to(x.dtype), error
class SAEInverseFeaturizerModule(torch.nn.Module):
"""Inverse for :class:`SAEFeaturizerModule`."""
def __init__(self, sae):
super().__init__()
self.sae = sae
def forward(self, features, error):
return (
self.sae.decode(features.to(self.sae.dtype)).to(features.dtype)
+ error.to(features.dtype)
)
class SAEFeaturizer(Featurizer):
    """Featurizer backed by a pre-trained sparse auto-encoder.
    Notes
    -----
    Serialisation is *disabled* for SAE featurizers –
    :meth:`Featurizer.save_modules` returns ``(None, None)``; reload the SAE
    from ``sae_lens`` instead.
    """
def __init__(self, sae, *, trainable: bool = False):
sae.requires_grad_(trainable)
super().__init__(
SAEFeaturizerModule(sae),
SAEInverseFeaturizerModule(sae),
n_features=sae.cfg.to_dict()["d_sae"],
id="sae",
)
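# Example (hedged – assumes the external `sae_lens` package; the release and
# id strings are illustrative):
#
#     from sae_lens import SAE
#     sae, _, _ = SAE.from_pretrained(
#         release="gpt2-small-res-jb", sae_id="blocks.7.hook_resid_pre"
#     )
#     feat = SAEFeaturizer(sae)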
# --------------------------------------------------------------------------- #
# Utility helpers #
# --------------------------------------------------------------------------- #
def _subspace_is_all_none(subspaces) -> bool:
"""Return ``True`` if *every* element of *subspaces* is ``None``."""
return subspaces is None or all(
inner is None or all(elem is None for elem in inner) for inner in subspaces
)