It seems like it's possible, right?
Interestingly, a similar tool run by a very smart guy (https://www.creatorml.com/) is apparently about to shut down, so it might not be a financially sustainable thing to try.
I couldn't find a link to the code in the article, so in case anyone else wants to try to replicate, I think this is it: https://github.com/HugoFry/mats_sae_training_for_ViTs
Just to close the loop on this one: the official Hugging Face transformers library just uses a for-loop over experts to implement MoE. I also implemented a version myself with a for-loop, and for large latent and batch sizes it's much more efficient than either vanilla matrix multiplication or the weird batched matmul I wrote up there.
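For anyone wondering what the for-loop version might look like, here is a minimal sketch of a mixture-of-experts SAE forward pass, assuming top-1 routing with a softmax linear router. The class name `ForLoopMoESAE`, scaling the reconstruction by the gate probability, and the ReLU encoder are hypothetical choices for illustration, not the author's code or the transformers implementation. Each expert only processes the rows routed to it, so no per-example copies of the weights are ever built.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ForLoopMoESAE(nn.Module):
    """Hypothetical top-1 MoE sparse autoencoder that loops over experts."""

    def __init__(self, input_dim: int, latent_dim: int, num_experts: int):
        super().__init__()
        self.num_experts = num_experts
        self.latent_dim = latent_dim
        self.router = nn.Linear(input_dim, num_experts)
        # one encoder and one decoder matrix per expert, stored once
        self.enc = nn.Parameter(torch.randn(num_experts, input_dim, latent_dim) * 0.01)
        self.dec = nn.Parameter(torch.randn(num_experts, latent_dim, input_dim) * 0.01)

    def forward(self, x: torch.Tensor):
        # x: (bs, input_dim)
        probs = F.softmax(self.router(x), dim=-1)   # (bs, num_experts)
        gate, expert_idx = probs.max(dim=-1)        # top-1 expert per row
        recon = torch.zeros_like(x)
        latent = x.new_zeros(x.shape[0], self.latent_dim)
        for e in range(self.num_experts):
            rows = (expert_idx == e).nonzero(as_tuple=True)[0]
            if rows.numel() == 0:
                continue  # no rows routed to this expert in this batch
            z = F.relu(x[rows] @ self.enc[e])       # (n_e, latent_dim)
            latent[rows] = z
            # assumption: scale by gate prob so the router gets a gradient
            recon[rows] = gate[rows, None] * (z @ self.dec[e])
        return recon, latent, expert_idx


torch.manual_seed(0)
model = ForLoopMoESAE(input_dim=8, latent_dim=32, num_experts=4)
x = torch.randn(16, 8)
recon, latent, expert_idx = model(x)
```

The loop runs `num_experts` dense matmuls per batch, each on a contiguous slice of rows, which is the kind of workload GPUs handle well.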
Wait a minute... could you just...
You don't literally just do this, do you?
import torch

input = torch.tensor([
    [1, 2],
    [1, 2],
    [1, 2],
], dtype=torch.float32)  # (bs, input_dim)
enc_expert_1 = torch.tensor([
    [1, 1, 1, 1],
    [1, 1, 1, 1],
], dtype=torch.float32)  # (input_dim, latent_dim)
enc_expert_2 = torch.tensor([
    [3, 3, 0, 0],
    [0, 0, 2, 0],
], dtype=torch.float32)
dec_expert_1 = torch.tensor([
    [-1, -1],
    [-1, -1],
    [-1, -1],
    [-1, -1],
], dtype=torch.float32)  # (latent_dim, input_dim)
dec_expert_2 = torch.tensor([
    [-10, -10],
    [-10, -10],
    [-10, -10],
    [-10, -10],
], dtype=torch.float32)

def moe(input, enc, dec, nonlinearity):
    # enc: (bs, input_dim, latent_dim), dec: (bs, latent_dim, input_dim),
    # i.e. one expert's weights per row of the batch
    input = input.unsqueeze(1)                         # (bs, 1, input_dim)
    latent = torch.bmm(input, enc)                     # (bs, 1, latent_dim)
    recon = torch.bmm(nonlinearity(latent), dec)       # (bs, 1, input_dim)
    return recon.squeeze(1), latent.squeeze(1)

# not this! some kind of actual routing algorithm, but you end up with something similar
enc = torch.stack([enc_expert_1, enc_expert_2, enc_expert_1])
dec = torch.stack([dec_expert_1, dec_expert_2, dec_expert_1])
nonlinearity = torch.nn.ReLU()
recons, latent = moe(input, enc, dec, nonlinearity)
# recons -> [[-12, -12], [-100, -100], [-12, -12]]
This must in some way be horrifically inefficient, right?
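A rough back-of-the-envelope for why it might be: stacking per-example expert weights, as in the snippet above, materializes one encoder and one decoder matrix per row of the batch, while looping over experts only ever touches `num_experts` of each. The helper names and the sizes below (batch 4096, input dim 768, latent dim 16384, 8 experts, float32) are hypothetical, chosen only to make the arithmetic concrete.

```python
def gathered_weight_bytes(batch_size, input_dim, latent_dim, bytes_per_elem=4):
    # torch.stack-style routing: every row of the batch carries its own copy
    # of an encoder (input_dim x latent_dim) and decoder (latent_dim x input_dim)
    per_matrix = input_dim * latent_dim * bytes_per_elem
    return 2 * batch_size * per_matrix


def shared_weight_bytes(num_experts, input_dim, latent_dim, bytes_per_elem=4):
    # for-loop routing: each expert's encoder/decoder is stored once and reused
    per_matrix = input_dim * latent_dim * bytes_per_elem
    return 2 * num_experts * per_matrix


# hypothetical SAE sizes, float32 weights
gathered = gathered_weight_bytes(batch_size=4096, input_dim=768, latent_dim=16384)
shared = shared_weight_bytes(num_experts=8, input_dim=768, latent_dim=16384)
print(gathered / 2**30, "GiB")  # 384.0 GiB
print(shared / 2**20, "MiB")    # 768.0 MiB
```

The gap is a factor of `batch_size / num_experts` (512 here), and on top of that each `bmm` slice is a matrix-vector product, which tends to use the GPU poorly.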
Can I ask what you used to implement the MoE routing? Did you use MegaBlocks? I would love to expand on this research, but I can't find any straightforward implementation of efficient PyTorch MoE routing online.
Do you simply iterate over each max-probability expert every time you feed in a batch?
This is dope, thank you for your service. Also, can you hit us with your code on this one? Would love to reproduce.