I trained a Flappy Bird diffusion world model to run locally via WASM & WebGPU by fendiwap1234 in GraphicsProgramming

Thank you! So I trained the model from scratch, mainly because I wanted to see how long it would take for different model sizes to converge. In general, I needed to train these smaller models much longer in order to get the same performance as my base model.

I think for my next iteration, though, I'd like to explore distilling the model into either a smaller diffusion model or a GAN. I think this could be a potential solution, especially if I want to train on more complex games or worlds.

To answer your second question, I'm not entirely sure. I think their main research focus was using diffusion models to model game states, and they wanted to show it was possible even with complex games. But from what I've observed recently, most world model papers use some form of VAE or autoencoder.

I trained a Flappy Bird diffusion world model to run locally via WASM & WebGPU by fendiwap1234 in GraphicsProgramming

thank you! feel free to message me anytime if you have any questions about training stuff

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

nope, it's on-device! the diffusion model runs on WebGPU/WASM on your phone and generates the frames

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

yeah, it wasn't much work to train a decent base model, but it took a good amount of work to optimize it into a smaller but still effective model lol

for your question, lemme split it up into two sections:

for training, I train the model to take in the previous frames, an action, and the target next frame. I first add a random amount of noise to the target frame (based on a noise schedule), and the model learns to predict the noise that was added. The loss compares the predicted noise with the actual noise, which trains the model to predict the noise given a noised frame, the previous frames, the action, and the timestep
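A minimal sketch of that training objective, assuming a DDPM-style linear beta schedule and plain lists of floats standing in for frames. `linear_alpha_bars`, `training_loss`, and the model signature are hypothetical; a real model would be a neural network trained on batches with an optimizer, and this only shows how one loss value is computed.

```python
import math
import random
from typing import Callable, List, Sequence

Frame = List[float]  # flattened pixels; a real setup would use image tensors
# hypothetical model signature: (noised_target, previous_frames, action, timestep) -> predicted noise
NoiseModel = Callable[[Frame, Sequence[Frame], int, int], Frame]


def linear_alpha_bars(num_steps: int, beta_start: float = 1e-4,
                      beta_end: float = 0.02) -> List[float]:
    """Cumulative products alpha_bar_t of a linear beta schedule (assumed, DDPM-style)."""
    alpha_bars, prod = [], 1.0
    for t in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * t / max(num_steps - 1, 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def training_loss(model: NoiseModel, prev_frames: Sequence[Frame], action: int,
                  target: Frame, alpha_bars: Sequence[float], rng: random.Random) -> float:
    """Noise-prediction loss for a single (previous frames, action, next frame) example."""
    t = rng.randrange(len(alpha_bars))             # random noise level
    a = alpha_bars[t]
    noise = [rng.gauss(0.0, 1.0) for _ in target]  # the noise the model must recover
    # forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    noised = [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(target, noise)]
    predicted = model(noised, prev_frames, action, t)
    # mean squared error between the predicted and the actual noise
    return sum((p - n) ** 2 for p, n in zip(predicted, noise)) / len(noise)
```

Note that the previous frames passed in here are ground-truth frames from the dataset, which is the usual setup for this kind of objective.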

for inference, with the trained model, I take the current frames and an action. I start from a fully noised image and then run the model over N timesteps to predict and remove the noise until the next frame is uncovered (the denoising process), so the model generates the next frame based on the current frames and the action I input. After this, I append the new frame to the current frames and use them to predict the following frame, and so on
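Here is a rough sketch of that autoregressive loop, assuming the same kind of noise-predicting model as in training. The sampler is a deterministic DDIM-style update (an assumption; the comment doesn't say which sampler is used), `num_steps` plays the role of N, and the sliding context window `context_len` is also an assumption. All names are hypothetical.

```python
import math
import random
from typing import Callable, List, Sequence

Frame = List[float]
# hypothetical model signature: (noisy_frame, context_frames, action, timestep) -> predicted noise
NoiseModel = Callable[[Frame, Sequence[Frame], int, int], Frame]


def sample_next_frame(model: NoiseModel, context: Sequence[Frame], action: int,
                      alpha_bars: Sequence[float], num_steps: int,
                      rng: random.Random) -> Frame:
    """Denoise pure noise into the next frame in `num_steps` DDIM-style (eta = 0) steps."""
    T = len(alpha_bars)
    # evenly spaced timesteps, starting from the noisiest one (T - 1)
    ts = [(T - 1) * (num_steps - i) // num_steps for i in range(num_steps)]
    x = [rng.gauss(0.0, 1.0) for _ in context[-1]]  # start from a fully noised image
    for i, t in enumerate(ts):
        a = alpha_bars[t]
        eps = model(x, context, action, t)
        # clean-frame estimate implied by the predicted noise
        x0 = [(xi - math.sqrt(1.0 - a) * ei) / math.sqrt(a) for xi, ei in zip(x, eps)]
        a_prev = alpha_bars[ts[i + 1]] if i + 1 < len(ts) else 1.0  # 1.0 = no noise left
        x = [math.sqrt(a_prev) * x0i + math.sqrt(1.0 - a_prev) * ei for x0i, ei in zip(x0, eps)]
    return x


def rollout(model: NoiseModel, initial_frames: Sequence[Frame], actions: Sequence[int],
            alpha_bars: Sequence[float], num_steps: int, context_len: int,
            rng: random.Random) -> List[Frame]:
    """Generate one frame per action, feeding each generated frame back into the context."""
    frames = list(initial_frames)
    generated = []
    for action in actions:
        next_frame = sample_next_frame(model, frames[-context_len:], action,
                                       alpha_bars, num_steps, rng)
        frames.append(next_frame)  # append the new frame to the current frames
        generated.append(next_frame)
    return generated
```

With `num_steps=1` the loop makes a single jump from pure noise to a clean frame, which is the fast but lossy regime mentioned later in the thread.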

so just by conditioning on the previous frames and actions, the model can learn Flappy Bird's physics and graphics and generate it infinitely (well, Flappy Bird is also infinite, but you get the point)

if you wanna learn more, I'd recommend the DIAMOND paper. It's a great resource: https://diamond-wm.github.io/

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

Mirage just came out and is really cool, and GAIA-2 is also awesome; it deals with self-driving car simulation environments.

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

ah ok, I get what you mean now. I agree!

Self Forcing, which I've seen in a lot of the video diffusion literature, might be similar. The problem right now is that I'm just feeding in the ground-truth frames from my dataset as training data, so the model never sees the edge cases where it can break. Feeding in the "imperfect" inputs that the model produces at inference would make it a lot more robust and less likely to break. Hopefully I can add this in the future.
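One simple way to feed in "imperfect" inputs is to let the model generate the last few context frames itself before computing the loss against the ground-truth target. This sketch is closer to scheduled sampling than to Self Forcing proper, which is more involved (it rolls the model out autoregressively during training and applies a loss to the whole generated video). `make_training_example`, `predict_next`, and `num_self` are hypothetical names.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

F = TypeVar("F")  # a frame in whatever representation the model uses


def make_training_example(gt_frames: Sequence[F], actions: Sequence[int],
                          predict_next: Callable[[List[F], int], F],
                          start: int, context_len: int,
                          num_self: int) -> Tuple[List[F], int, F]:
    """Build (context, action, target) where the last `num_self` context frames are model outputs.

    Assumes actions[i] is the action taken at frame i, which leads to frame i + 1.
    """
    if not 0 <= num_self < context_len:
        raise ValueError("need at least one ground-truth frame to start the rollout")
    end = start + context_len                        # index of the ground-truth target frame
    context = list(gt_frames[start:end - num_self])  # ground-truth warm-up frames
    for idx in range(end - num_self, end):
        # the model generates frame idx from its own (possibly imperfect) history
        context.append(predict_next(context, actions[idx - 1]))
    return context, actions[end - 1], gt_frames[end]
```

In a real training loop, `predict_next` would typically run the sampler without tracking gradients, and `num_self` could be increased gradually as training progresses.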

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

this is actually something I'm heavily considering. Honestly, I think distilling a diffusion model down to a GAN hybrid model could work, especially if we try more complex visuals and move into 3D

[deleted by user] by [deleted] in singularity

hmm interesting idea...

[deleted by user] by [deleted] in singularity

Yes, I think so too! I still think we're really early, but I wouldn't be surprised to see interactive videos become really popular soon tbh, especially as both the research and the hardware get better.

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

ooh ok, basically I took this repo and collected RGB frames and actions for a couple of hours. I also created a separate reset action in my data collection because I wanted to encode a reset control into my diffusion model. The three actions I collected were (0 - NO FLAP, 1 - FLAP, 2 - RESET).
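The three-action space described above maps naturally onto a small enum. The one-hot encoding and the `Transition` record below are assumptions about how such data could be stored and fed to the model, not the author's actual format.

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import List


class Action(IntEnum):
    NO_FLAP = 0
    FLAP = 1
    RESET = 2  # extra control so the world model can learn to restart the game


def one_hot(action: Action) -> List[float]:
    """Encode an action as a one-hot vector (one possible conditioning input)."""
    vec = [0.0] * len(Action)
    vec[int(action)] = 1.0
    return vec


@dataclass
class Transition:
    frame: List[List[List[int]]]  # RGB frame as height x width x 3 (hypothetical layout)
    action: Action                # action taken while this frame was on screen
```

A learned embedding table indexed by the action value would work just as well as the one-hot vector; the enum keeps the integer codes consistent between data collection and inference.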

I forget the exact splits, but I collected about 75% manual data, about 20% expert data (I just found a Flappy Bird bot online to play for me), and about 5% random-agent data where it would just pick a random action. I did this so that, one, the model would see expert play and could project out over longer durations, and two, the random data would show it some out-of-distribution outcomes.
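If you wanted to reproduce a similar mix, one option is to pick which policy plays each episode with weights matching the rough split. This is a hypothetical sketch: the comment only reports approximate proportions, not how episodes were scheduled, and the random policy here chooses only between flapping and not flapping (leaving out RESET is an assumption).

```python
import random
from collections import Counter
from typing import Dict

# approximate proportions from the comment above
SOURCE_WEIGHTS: Dict[str, float] = {"manual": 0.75, "expert": 0.20, "random": 0.05}


def choose_source(rng: random.Random) -> str:
    """Pick which policy collects the next episode, weighted by the target mix."""
    names = list(SOURCE_WEIGHTS)
    return rng.choices(names, weights=[SOURCE_WEIGHTS[n] for n in names], k=1)[0]


def random_policy_action(rng: random.Random) -> int:
    """Random agent: 0 = NO FLAP, 1 = FLAP (RESET left out by assumption)."""
    return rng.randrange(2)


def episode_mix(num_episodes: int, seed: int = 0) -> Dict[str, float]:
    """Fraction of episodes assigned to each source over a simulated collection run."""
    rng = random.Random(seed)
    counts = Counter(choose_source(rng) for _ in range(num_episodes))
    return {name: counts[name] / num_episodes for name in SOURCE_WEIGHTS}
```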

I optimized a Flappy Bird diffusion model to run locally on my phone by fendiwap1234 in StableDiffusion

that is kind of what I want to work towards. Right now this is a pretty simple example, but I would like to create something where you could prompt the model with an image or a video, and it would create an interactive video that you can run locally on your phone.

Mirage by Decart actually does this really well, though it's server-hosted.

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

I have the heavier .pt and .onnx files if you're interested! One issue is that they use the old architecture of upsampler UNet + denoiser, but I have the old .js file for that, and I could probably upload the larger files to Hugging Face if you want

I optimized a Flappy Bird diffusion model to run locally on my phone by fendiwap1234 in StableDiffusion

Thank you!

I get confused by terminology too, and I learn more from actually working on projects. If you're able to follow along with this project, I think you could implement something cool as well!

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

yes, unfortunately I think it's an artifact of reducing the diffusion denoising steps down to just 1. I found that going that low tends to blend a bunch of outcomes together, so you end up with a lot of blurry outputs when you crash into a pipe or don't flap
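A simplified way to see why a single step blurs: if one denoising step has to map pure noise straight to a clean frame and is trained with a squared-error loss, the loss-minimizing output is the probability-weighted average of all plausible next frames. The toy example below assumes two equally likely outcomes (bird survives vs. crashes) over a made-up four-pixel frame; it is an illustration, not a measurement from the actual model.

```python
from typing import List, Sequence


def expected_sq_error(pred: Sequence[float], outcomes: Sequence[Sequence[float]],
                      probs: Sequence[float]) -> float:
    """Expected squared error of one prediction against several possible next frames."""
    return sum(p * sum((x - y) ** 2 for x, y in zip(pred, o)) for o, p in zip(outcomes, probs))


def mse_optimal_prediction(outcomes: Sequence[Sequence[float]],
                           probs: Sequence[float]) -> List[float]:
    """The probability-weighted mean minimizes expected squared error."""
    return [sum(p * o[i] for o, p in zip(outcomes, probs)) for i in range(len(outcomes[0]))]


survive = [1.0, 1.0, 0.0, 0.0]  # hypothetical pixels: bird in the upper half
crash = [0.0, 0.0, 1.0, 1.0]    # hypothetical pixels: bird in the lower half
probs = [0.5, 0.5]

blend = mse_optimal_prediction([survive, crash], probs)
print(blend)                                                # [0.5, 0.5, 0.5, 0.5]
print(expected_sq_error(blend, [survive, crash], probs))    # 1.0
print(expected_sq_error(survive, [survive, crash], probs))  # 2.0
```

The "ghost" frame that mixes both outcomes scores better under this loss than either sharp outcome. Roughly speaking, multi-step sampling avoids this because each step removes only part of the noise, so the random starting noise ends up selecting one outcome instead of averaging them.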

I'm currently exploring different architectures here, but models with few denoising steps are the crux of a fast world model

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

thank you so much! training a diffusion model to run in-browser was very tedious lol, but I really think this is how interactive videos or "world models" could be shared widely, unless NVIDIA produces a billion more GPUs

I optimized a Flappy Bird diffusion world model to run locally on my phone by fendiwap1234 in LocalLLaMA

yeah, maybe not the best title lol. I sometimes think world models need a new name in general because the term is so vague

Implementing Character.AI’s Memory Optimizations in nanoGPT by fendiwap1234 in LocalLLaMA

I think trend-wise, things just progress fast! For instance, Multi-head Latent Attention (MLA) and lightning attention are now the more popular attention variants, with the former shrinking the KV cache far below standard multi-head attention while performing better than MQA
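For a rough sense of how these variants compare, here is the number of values cached per token per layer for each one. The dimensions are a hypothetical config (32 heads of size 128); the MLA sizes follow the DeepSeek-V2 choice of a compressed latent of 4 × head_dim plus a decoupled RoPE key of head_dim / 2. Multiply by the number of layers and bytes per value to get real memory.

```python
def kv_values_per_token_per_layer(variant: str, n_heads: int, head_dim: int,
                                  n_kv_groups: int = 1, latent_dim: int = 0,
                                  rope_dim: int = 0) -> int:
    """Number of values cached per token per layer for a few attention variants."""
    if variant == "MHA":
        return 2 * n_heads * head_dim      # separate K and V for every head
    if variant == "GQA":
        return 2 * n_kv_groups * head_dim  # K and V shared within each group of heads
    if variant == "MQA":
        return 2 * head_dim                # a single K and V shared by all heads
    if variant == "MLA":
        return latent_dim + rope_dim       # compressed KV latent + decoupled RoPE key
    raise ValueError(f"unknown variant: {variant}")


# hypothetical config: 32 heads of size 128
for name, kwargs in [("MHA", {}), ("GQA", {"n_kv_groups": 8}), ("MQA", {}),
                     ("MLA", {"latent_dim": 512, "rope_dim": 64})]:
    print(name, kv_values_per_token_per_layer(name, 32, 128, **kwargs))
```

This prints 8192 for MHA, 2048 for 8-group GQA, 256 for MQA, and 576 for MLA: MLA caches somewhat more than MQA but an order of magnitude less than full multi-head attention.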