Llama3-like weight init #435
Conversation
match_count += 1
hits[weight_regex] += 1
if match_count == 0:
    logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
should we add a flag which turns this into an error?
Since the norms are initialized within the model factory via reset_parameters, this would always throw an error.
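A sketch of what such a strict flag could look like (the `apply_regex_init` function and the `strict` parameter are hypothetical illustrations, not part of this PR):

```python
import logging
import re

logger = logging.getLogger(__name__)


def apply_regex_init(model, regex_to_init, strict: bool = False):
    """Apply init functions to parameters whose names fully match a regex.

    strict=True raises instead of warning when a parameter matches no
    pattern (the hypothetical flag discussed in this thread).
    """
    for parameter_name, parameter in model.named_parameters():
        match_count = 0
        for weight_regex, init_fn in regex_to_init.items():
            if re.fullmatch(weight_regex, parameter_name):
                init_fn(parameter)
                match_count += 1
        if match_count == 0:
            msg = f"Parameter {parameter_name} did not match any regex for initialization"
            if strict:
                raise ValueError(msg)
            logger.warning(msg)
```

As the reply notes, with norms initialized separately via reset_parameters, `strict=True` would fire on every norm weight unless those names are also covered by a pattern.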
b=2,
),
# final attention projection in attention block
r"transformer\.h\.\d+\.attn\.c_proj\.weight": partial(
This corresponds to the following, right? But there you can see that for the out projection its std=init_std, which can be initialized differently and defaults to depth_init, because here we pass weight_init_std, which defaults to depth_init in titan here. If we don't want depth init, then it matches the scaled out_projections logic when depth_init is False for titan.
I implemented depth_init to be fully compliant.
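The two std choices being compared can be sketched as follows (a minimal sketch based on the titan logic linked above; the base std of 0.02 and the exact scaling are assumptions, not taken from this PR's diff):

```python
def weight_init_std(
    layer_id: int, num_layers: int, depth_init: bool, base_std: float = 0.02
) -> float:
    """Std for the output projection: per-layer scaling when depth_init
    is enabled, otherwise a single depth-independent scale."""
    if depth_init:
        # scale by the (1-indexed) depth of this particular block
        return base_std / (2 * (layer_id + 1)) ** 0.5
    # scale by the total number of layers, identical for every block
    return base_std / (2 * num_layers) ** 0.5
```

Note that for the last layer the two variants coincide; they differ only for earlier layers, where depth_init uses a larger std.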
def __init__(self, num_layers: int, n_embd: int, bias: bool) -> None:
    super().__init__()

    self.regex_to_init = {
we also need regex patterns for attention_norm, ffn_norm, and the final lm_head_norm, no? Something like
r"transformer\.h\.\d+\.(attention_norm|ffn_norm)\.weight": nn.init.ones_,
r"transformer\.lm_head_norm\.weight": nn.init.ones_,
we already call this here.
and due to recursion we also call it for the RMSNorm.
http://github.com/pytorch/pytorch/blob/65762ca85745d786ab6b20e9cb060242b51e872d/torch/nn/modules/normalization.py#L407
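The recursion argument above can be illustrated framework-free: calling reset_parameters on every submodule already restores norm weights to ones, so explicit regex entries for the norms would be redundant. The Norm and Block classes below are hypothetical stand-ins for RMSNorm and a transformer block, not the actual model code:

```python
class Norm:
    """Stand-in for RMSNorm: reset_parameters restores the weight to ones."""

    def __init__(self, dim: int) -> None:
        self.weight = [0.0] * dim

    def reset_parameters(self) -> None:
        self.weight = [1.0] * len(self.weight)


class Block:
    """Stand-in for a transformer block holding two norms."""

    def __init__(self) -> None:
        self.attention_norm = Norm(4)
        self.ffn_norm = Norm(4)


def reset_all(module) -> None:
    """Recursively call reset_parameters on a module and its children,
    mirroring how the model factory resets the norms."""
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
    for child in vars(module).values():
        if hasattr(child, "__dict__"):
            reset_all(child)
```

In the real model, nn.Module recursion plays the role of `reset_all`, and RMSNorm's own reset_parameters (linked above) fills the weight with ones.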
if re.fullmatch(weight_regex, parameter_name):
    init_fn, arg_dict = regex_to_init[weight_regex]
    if arg_dict["std"] is not None and callable(arg_dict["std"]):
        if not depth_init:
Isn't this dead code now? std becomes a callable only when depth_init is True, right? So this check is not needed.
I added it as a safety check, but you're right, it's kind of redundant. I'll remove it!
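With the redundant depth_init check removed, resolving std reduces to something like the following (`resolve_std` is a hypothetical helper simplified from the diff above, not the PR's actual code):

```python
def resolve_std(arg_dict: dict, layer_id: int):
    """Resolve the std for one init entry: when depth_init is on, std is
    stored as a callable of the layer index; otherwise it is a plain
    float (or None for init functions that take no std)."""
    std = arg_dict.get("std")
    if callable(std):
        return std(layer_id)
    return std
```

Since a callable std is only ever stored when depth_init is True, no separate `if not depth_init` branch is needed.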
What does this PR do?
This PR ..
General Changes
Breaking Changes
Checklist before submitting final PR
Ran the test suite (python tests/tests.py)
Updated the changelog (CHANGELOG_DEV.md)