1.2.2 Image and Probability Distribution
The value of each channel of an RGB image lies in the range [0, 255].
We normalize each channel via $(R/255,\ G/255,\ B/255)$ to the range [0, 1].
Probability distribution of a single image channel (1D Gaussian)
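As a minimal sketch (hypothetical tensor, using plain division by 255 rather than any library-specific transform), the normalization and the single-channel statistics look like this:

```python
import torch

# Hypothetical 8-bit RGB image tensor of shape (3, H, W)
rgb_uint8 = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)

# Normalize each channel from [0, 255] to [0, 1]
rgb = rgb_uint8.float() / 255.0

# Single-channel statistics (the 1D Gaussian view): mean and std of one channel
red = rgb[0]
print(red.mean().item(), red.std().item())
```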
Probability distribution of two image channels (2D Gaussian)
$\boldsymbol{\mu}=[\mu_{x_1},\mu_{x_2}]=[\mu_{blue},\mu_{green}]$
$\Sigma=\begin{bmatrix} \sigma_{x_1}^2 & \sigma_{x_1,x_2}\\ \sigma_{x_2,x_1} & \sigma_{x_2}^2 \end{bmatrix}=\begin{bmatrix} \sigma_{blue}^2 & \sigma_{blue,green}\\ \sigma_{green,blue} & \sigma_{green}^2 \end{bmatrix}$
Probability distribution of three image channels (3D Gaussian)
$\boldsymbol{\mu}=[\mu_{x},\mu_{y},\mu_{z}]=[\mu_{red},\mu_{green},\mu_{blue}]$
$\Sigma=\begin{bmatrix} \sigma_{x}^2 & \sigma_{xy} & \sigma_{xz}\\ \sigma_{yx} & \sigma_{y}^2 & \sigma_{yz}\\ \sigma_{zx} & \sigma_{zy} & \sigma_{z}^2 \end{bmatrix}$
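As a sketch of how these statistics could be estimated from an actual image (assuming a normalized `(3, H, W)` tensor; `torch.cov` treats each row as one variable):

```python
import torch

# Hypothetical normalized image: shape (3, H, W), values in [0, 1]
rgb = torch.rand(3, 64, 64)

# Flatten the spatial dimensions: one row per channel, one column per pixel
pixels = rgb.reshape(3, -1)      # (3, H*W)

mu = pixels.mean(dim=1)          # mean vector [mu_R, mu_G, mu_B]
sigma = torch.cov(pixels)        # 3x3 covariance matrix between channels
print(mu, sigma)
```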
During Stable Diffusion training we add noise to the clean image, so we sample from a 3D standard normal distribution; the sampled tensor then has the same shape as the image tensor.
$\epsilon \sim \mathcal{N}(0, I)$
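In PyTorch this is just a standard-normal tensor with the image tensor's shape, for example:

```python
import torch

# Hypothetical batch of clean (latent) images: (batch_size, channels, H, W)
clean_images = torch.rand(4, 4, 32, 32)

# epsilon ~ N(0, I), drawn i.i.d. per element, same shape as the image tensor
epsilon = torch.randn_like(clean_images)
```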
1.2.3 Maximize ELBO (Maximize Evidence Lower Bound)
We want to collect a large number of samples whose distribution is as close as possible to the true distribution (the distribution of all known image data).
By maximizing the probability of the samples (maximum likelihood), we push the sample distribution to match the true distribution as closely as possible.
The probability of the $i$-th sample image is $p_{\theta}(x^i)$. Multiplying the probabilities of the $m$ images in the dataset gives the joint likelihood; maximizing this likelihood yields the optimal parameters $\theta=(\mu,\sigma)$.
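Written out, the maximum-likelihood objective over the $m$ images is:

$\theta^{*}=\arg\max_{\theta}\prod_{i=1}^{m}p_{\theta}(x^{i})=\arg\max_{\theta}\sum_{i=1}^{m}\log p_{\theta}(x^{i})$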
Let us first take a quick look at the VAE objective.
We do not know the true distribution $p(x|z)$, but we learn to approximate it with a neural network. We want to learn $p(x|z)$ such that the generated images are as close to the training data distribution as possible; to that end we maximize the log-likelihood of the observed data, which gives the following formula.
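For reference, the standard VAE evidence lower bound takes the form of a reconstruction term minus a KL regularizer toward the prior:

$\log p_{\theta}(x)\ge\mathbb{E}_{q_{\phi}(z|x)}\big[\log p_{\theta}(x|z)\big]-D_{KL}\big(q_{\phi}(z|x)\,\|\,p(z)\big)$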
For the diffusion model's objective we proceed by analogy with the VAE.
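The analogous decomposition for the diffusion model, whose three terms are discussed below, can be written as:

$\log p(x_0)\ge\underbrace{\mathbb{E}_{q(x_1|x_0)}\big[\log p_{\theta}(x_0|x_1)\big]}_{\text{reconstruction term}}-\underbrace{D_{KL}\big(q(x_T|x_0)\,\|\,p(x_T)\big)}_{\text{prior matching term}}-\sum_{t=2}^{T}\underbrace{\mathbb{E}_{q(x_t|x_0)}\big[D_{KL}\big(q(x_{t-1}|x_t,x_0)\,\|\,p_{\theta}(x_{t-1}|x_t)\big)\big]}_{\text{denoising matching term}}$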
Let us first look at how $p_{\theta}(x_0|x_1)$ in the reconstruction term is computed.
From Section 3.3 of the DDPM paper.
Next, the prior matching term.
The KL divergence measures how similar the distributions $q(x_T|x_0)$ and $p(x_T)$ are. Here $q(x_T|x_0)$ is the forward noising process that takes $x_0$ to $x_T$, and $p(x_T)$ is the standard Gaussian prior. When $T$ is large enough the two distributions are practically identical, so we treat them as equal and this KL term is 0.
Finally, we turn to the denoising matching term.
The diffusion model training objective:
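In the noise-prediction parameterization used here, this objective (per timestep, following the DDPM paper) is:

$\mathbb{E}_{x_0,\epsilon}\left[\frac{\beta_t^{2}}{2\sigma_t^{2}\alpha_t(1-\bar{\alpha}_t)}\left\|\epsilon-\epsilon_{\theta}\big(\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon,\,t\big)\right\|^{2}\right]$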
Experiments in the DDPM paper show that training works well enough with this weighting coefficient dropped.
The simplified diffusion model training objective:
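Dropping the weighting coefficient leaves a plain MSE on the noise ($L_{\text{simple}}$ in the DDPM paper):

$L_{\text{simple}}(\theta)=\mathbb{E}_{t,x_0,\epsilon}\left[\left\|\epsilon-\epsilon_{\theta}\big(\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon,\,t\big)\right\|^{2}\right]$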
Stable Diffusion's UNet currently offers three prediction schemes (their relations are sketched after this list):
(1) The UNet directly predicts $x_0$, but this works poorly.
(2) The UNet predicts the noise to be removed (the scheme used in this training run).
(3) The UNet predicts the score.
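The three targets are interchangeable up to known scalings: with $x_t=\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon$,

$x_0=\dfrac{x_t-\sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}},\qquad \nabla_{x_t}\log q(x_t|x_0)=-\dfrac{\epsilon}{\sqrt{1-\bar{\alpha}_t}},$

so predicting any one of $x_0$, the noise $\epsilon$, or the score determines the other two.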
1.3 Training (from the DDPM paper)
batch size, iteration, and epoch
Training over the entire dataset once is one epoch; training on the dataset n times means n epochs. In other words, one epoch covers all the data in the dataset.
An epoch consists of multiple batches, and a batch consists of multiple images.
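For example (hypothetical numbers), the number of iterations (batches) per epoch is just the dataset size divided by the batch size:

```python
import math

dataset_size = 10_000  # hypothetical number of images in the dataset
batch_size = 10        # images per batch
epochs = 40            # full passes over the dataset

iterations_per_epoch = math.ceil(dataset_size / batch_size)  # 1000 batches per epoch
total_iterations = iterations_per_epoch * epochs             # 40000 training iterations overall
print(iterations_per_epoch, total_iterations)
```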
Full training code
```python
import os.path
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from ddpm import DDPMSampler
from diffusion import UNET, Diffusion
from pipeline import get_time_embedding
from create_dataset import train_loader

'''
Algorithm: Training (DDPM)
1: repeat
2:   x_0 ~ q(x_0)              # sample a batch from an epoch
                               # for epoch -> for batch -> for every image tensor: train_loader
3:   t ~ Uniform({1...T})      # sample a random t for every image tensor
                               # t: num_inference_step, T: num_training_step
                               # t = diffusion.sample_timesteps(images.shape[0]).to(device)
4:   epsilon ~ N(0, I)         # 3D standard normal distribution; the sampled noise tensor
                               # has the same shape as the image tensor
                               # noisy_image_tensor = add_noise(t)
5:   Take a gradient descent step on
       nabla_theta || epsilon - epsilon_theta(noisy image tensor, t, y) ||^2
6: until converged
'''

'''
1. Data Preprocessing
   (1) Loading and Transforming Data: data is loaded from the dataset and transformed to a suitable
       format for training. Common transformations include resizing, normalization, and converting to tensors.
   (2) Creating Data Loaders: data loaders are created to efficiently load the data in batches,
       shuffle the training data, and manage parallel processing.
2. Model Initialization
   (1) Define the UNet Model: the UNET architecture typically consists of an encoder-decoder structure
       with skip connections. The encoder captures context while the decoder enables precise localization.
   (2) Move Model to Device: the model is moved to the appropriate device (CPU or GPU) to leverage
       hardware acceleration.
3. Loss Function and Optimizer
   (1) Loss Function: measures the difference between the predicted output and the true output.
   (2) Optimizer: updates the model parameters to minimize the loss. Common optimizers include Adam, SGD, etc.
4. Training Loop
   (1) Set Model to Training Mode: model.train().
   (2) Iterate Over Data: for each epoch, iterate over batches of data.
       Forward Pass: pass input data through the model to get predictions.
         - A random time step t is selected for each training sample (image).
         - The Gaussian noise corresponding to t is applied to each image.
         - The time steps are converted to embeddings (vectors).
       Compute Loss: calculate the loss using the predictions and ground truth.
       Backward Pass: perform backpropagation to compute gradients.
       Update Parameters: use the optimizer to update model parameters based on the gradients.
   (3) Monitor Training: track and print training loss to monitor progress.
5. Validation
   After each epoch, validate the model on a separate validation set to ensure the model is not
   overfitting and to monitor its generalization performance.
6. Checkpoint Saving
   Save the model state, optimizer state, and any relevant training information after each epoch
   to allow training to be resumed if needed.
'''

# A PyTorch random number generator; manual_seed sets the seed for reproducibility.
generator = torch.Generator(device='cuda')
generator.manual_seed(42)

# Initialize the DDPMSampler with the random generator
ddpm_sampler = DDPMSampler(generator)
diffusion = Diffusion()


def timesteps_to_time_emb(timesteps):
    time_embeddings = []
    for i, timestep in enumerate(timesteps):
        # (1, 320)
        time_emb_320 = get_time_embedding(timestep).to('cuda')
        embedding = diffusion.time_embedding.to('cuda')
        time_embedding = embedding(time_emb_320).squeeze(0)  # (1280,)
        time_embeddings.append(time_embedding)
    return torch.stack(time_embeddings)  # Final shape: (batch_size, 1280)


print('Start training now !')


def train(args):
    device = args.device                                     # Get the device to run the training on
    model = UNET().to(device)                                # Initialize the model and move it to the device
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)  # Set up the optimizer with AdamW
    mse = nn.MSELoss()                                       # Mean Squared Error loss function
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)
    print('Start into the loop !')
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")             # Log the start of the epoch
        progress_bar = tqdm(train_loader)                    # Progress bar for the dataloader
        optimizer.zero_grad()                                # Explicitly zero the gradient buffers
        accumulation_steps = 4                               # Number of batches to accumulate before each optimizer step
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)                       # Move images to the device
            # The dataloader adds a batch dimension, but the VAE and CLIP inputs already carry one,
            # so squeeze it out and keep only the dataloader's batch dimension.
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)                   # Move captions to the device
            text_embeddings = torch.squeeze(captions, dim=1)
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)    # Sample random timesteps
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)  # Add noise to the images
            time_embeddings = timesteps_to_time_emb(timesteps)
            # x_t     (batch_size, channel, Height/8, Width/8)  (bs, 4, 256/8, 256/8)
            # caption (batch_size, seq_len, dim)                (bs, 77, 768)
            # t       (batch_size, channel)                     (bs, 1280)
            # Forward pass; gradients must flow through the UNet, so no torch.no_grad() here.
            last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)  # (bs, 320, H/8, W/8)
            final_output = diffusion.final.to(device)
            predicted_noise = final_output(last_decoder_noise).to(device)  # (bs, 4, H/8, W/8)
            loss = mse(noises, predicted_noise)              # Compute the loss
            loss.backward()                                  # Backpropagate the loss
            if (batch_idx + 1) % accumulation_steps == 0:    # Wait for several backward passes
                optimizer.step()                             # Now we can do an optimizer step
                optimizer.zero_grad()                        # Reset gradients to zero
            progress_bar.set_postfix(MSE=loss.item())        # Update the progress bar with the loss
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)  # Log the loss to TensorBoard
        # Save the model and optimizer checkpoints
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "stable_diffusion.ckpt"))
        torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, "optim.pt"))


def launch():
    import argparse                       # Command-line argument parsing
    parser = argparse.ArgumentParser()    # Create an argument parser
    args = parser.parse_args()            # Parse the command-line arguments
    # Set the default values for the arguments
    args.run_name = "Condition_Unet"      # Name for the run, used for logging and saving models
    args.epochs = 40                      # Number of epochs to train the model
    args.batch_size = 10                  # Batch size for the dataloader
    args.image_size = 256                 # Size of the images
    args.device = "cuda"                  # Device to run the training on ('cuda' for GPU or 'cpu')
    args.lr = 3e-4                        # Learning rate for the optimizer
    train(args)                           # Call the train function with the parsed arguments


if __name__ == '__main__':
    launch()                              # Run training when this script is executed directly
```
2.CUDA out of memory
2.1 Reasons
2.1.1 Large Batch Size
Using a batch size that is too large can quickly exhaust GPU memory, especially with large models or high-resolution images.
2.1.2 High Model Complexity
Complex models with many layers and parameters consume more memory. This includes architectures with large fully connected layers, extensive use of skip connections, or multi-headed attention mechanisms.
2.1.3 Large Input Data
High-resolution images or large input tensors consume more memory.
2.1.4 Insufficient Memory Management
Not clearing intermediate variables or not using memory-efficient operations can lead to memory leaks or inefficient memory usage.
2.1.5 Gradients and Optimizer States
Storing gradients and optimizer states, especially for adaptive optimizers like Adam or RMSprop, can be memory-intensive.
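As a rough, back-of-the-envelope estimate (hypothetical parameter count, FP32 training with Adam/AdamW, which keeps two extra state tensors per parameter):

```python
# Rough estimate of parameter-related GPU memory: weights + gradients + Adam states (all FP32)
num_params = 860_000_000      # hypothetical UNet parameter count
bytes_per_value = 4           # float32
copies = 1 + 1 + 2            # weights + gradients + Adam's exp_avg and exp_avg_sq

total_gb = num_params * bytes_per_value * copies / 1024**3
print(f"~{total_gb:.1f} GB before activations")  # activation memory comes on top of this
```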
2.1.6 Memory Fragmentation
Fragmentation occurs when memory is allocated and deallocated in such a way that it becomes difficult to find contiguous blocks of memory, leading to inefficient memory use.
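One setting that can help here (worth checking against the PyTorch version in use) is the caching allocator's `max_split_size_mb` option; it must be set before the first CUDA allocation, e.g.:

```python
import os

# Limit how large a cached block the allocator will split; smaller values can
# reduce fragmentation-related OOM errors at some performance cost.
# Must be set before any CUDA memory is allocated.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch
print(torch.cuda.is_available())
```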
2.2 Solutions
2.2.1 Reduce Batch Size
Decreasing the batch size is the simplest and most effective way to reduce memory usage.
```python
args.batch_size = 5  # Example: reduce the batch size
```
2.2.2 Use Mixed Precision Training
Mixed precision training can reduce memory usage by using 16-bit floats instead of 32-bit floats for certain operations.
The following is the author's Stable Diffusion training code as revised by GPT:
```python
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()


def train(args):
    device = args.device
    model = UNET().to(device)
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        progress_bar = tqdm(train_loader)
        optimizer.zero_grad()
        accumulation_steps = 4
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)
            text_embeddings = torch.squeeze(captions, dim=1)
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)
            time_embeddings = timesteps_to_time_emb(timesteps)
            # Run the forward pass and loss in mixed precision
            with autocast():
                last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
                final_output = diffusion.final.to(device)
                predicted_noise = final_output(last_decoder_noise).to(device)
                loss = mse(noises, predicted_noise)
            scaler.scale(loss).backward()
            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            progress_bar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)
            torch.cuda.empty_cache()
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "stable_diffusion.ckpt"))
        torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, "optim.pt"))
```
2.2.3 Gradient Accumulation
Accumulate gradients over multiple iterations before updating model parameters. This effectively simulates a larger batch size without increasing memory usage.
Accumulating gradients over multiple iterations means performing forward and backward passes on smaller mini-batches and summing the gradients over several iterations before updating the model parameters. This simulates a larger effective batch size without increasing memory usage, since you never need to hold the full batch in memory at once, which is especially useful when GPU memory is limited.
Figures: standard training loop vs. gradient accumulation.
Key Points
1.Batch Size vs. Mini-Batch Size:
(1) The original batch size is split into smaller mini-batches to fit into GPU memory.
(2) accumulation_steps * mini_batch_size = effective_batch_size.
2.Loss Scaling:
(1) The loss is divided by accumulation_steps to ensure that the gradient magnitudes remain consistent with what they would be if you processed the entire batch at once.
3.Optimizer Step and Gradient Zeroing:
(1) The optimizer step is performed, and gradients are zeroed only after accumulating gradients over several mini-batches.
```python
from torch.cuda.amp import GradScaler, autocast

# Assuming you have defined your model, optimizer, loss function, and data loader
model = UNET().to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
scaler = GradScaler()
mse = nn.MSELoss()
accumulation_steps = 4  # Number of mini-batches to accumulate gradients over

for epoch in range(args.epochs):
    model.train()
    optimizer.zero_grad()
    for batch_idx, (images, captions) in enumerate(train_loader):
        images = images.to(device)
        captions = captions.to(device)
        text_embeddings = torch.squeeze(captions, dim=1)
        timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)
        noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)
        time_embeddings = timesteps_to_time_emb(timesteps)
        with autocast():
            last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
            final_output = diffusion.final.to(device)
            predicted_noise = final_output(last_decoder_noise).to(device)
            # Scale the loss so accumulated gradients match a full-batch update
            loss = mse(noises, predicted_noise) / accumulation_steps
        scaler.scale(loss).backward()  # Accumulate gradients but do not update the weights yet
        if (batch_idx + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    # Optional: save model checkpoint after each epoch
    torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")
```
2.2.4 Clear Cache
Manually clear the GPU cache to free up unused memory.
```python
from torch.cuda.amp import GradScaler, autocast


def train(args):
    device = args.device
    model = UNET().to(device)
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    scaler = GradScaler()
    mse = nn.MSELoss()
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        progress_bar = tqdm(train_loader)
        optimizer.zero_grad()
        accumulation_steps = 4
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)
            text_embeddings = torch.squeeze(captions, dim=1)
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)
            time_embeddings = timesteps_to_time_emb(timesteps)
            with autocast():
                last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
                final_output = diffusion.final.to(device)
                predicted_noise = final_output(last_decoder_noise).to(device)
                loss = mse(noises, predicted_noise) / accumulation_steps
            scaler.scale(loss).backward()
            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                # Clear cache to free up memory
                torch.cuda.empty_cache()
            progress_bar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)
        # Save model checkpoint after each epoch
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "stable_diffusion.ckpt"))
        torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, "optim.pt"))
```