I happen to be training an IP-Adapter for a DiT, and Flux, as a DiT architecture, is the first to have a successfully trained IP-Adapter, which makes it a handy reference. There is no training script published, but from the inference script we can make a shallow guess at which modules and layers they trained.
class LoadFluxIPAdapter:
    @classmethod
    def INPUT_TYPES(s):
        ……
    def loadmodel(self, ipadatper, clip_vision, provider):
        pbar = ProgressBar(6)
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        pbar.update(1)
        ret_ipa = {}
        path = os.path.join(dir_xlabs_ipadapters, ipadatper)
        ckpt = load_safetensors(path)
        pbar.update(1)
        path_clip = folder_paths.get_full_path("clip_vision", clip_vision)
        try:
            clip = FluxClipViT(path_clip)
        except:
            clip = load_clip_vision(path_clip).model
        ret_ipa["clip_vision"] = clip
        prefix = "double_blocks."
        blocks = {}
        proj = {}
        for key, value in ckpt.items():
            if key.startswith(prefix):
                blocks[key[len(prefix):].replace('.processor.', '.')] = value
            if key.startswith("ip_adapter_proj_model"):
                proj[key[len("ip_adapter_proj_model."):]] = value
        pbar.update(1)
        improj = ImageProjModel(4096, 768, 4)
        improj.load_state_dict(proj)
        pbar.update(1)
        ret_ipa["ip_adapter_proj_model"] = improj
        ret_ipa["double_blocks"] = torch.nn.ModuleList([IPProcessor(4096, 3072) for i in range(19)])
        ret_ipa["double_blocks"].load_state_dict(blocks)
        pbar.update(1)
        return (ret_ipa,)
First, look at what gets loaded: an ip_adapter model and a clip_image_model. The ip_adapter checkpoint contains the weights that get embedded into the network layers, plus a separately trained ImageProjModel, which we will look at in a moment. The top half is easy to read; below it there is double_blocks, which uses IPProcessor. Given the Flux architecture, this should target the parallel attention layers inside each block, and IPProcessor presumably does the attention computation itself; once we get to its code we will explain it properly.
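To make the key remapping in the loop concrete, here is a tiny standalone demo; the key names below are illustrative, following the pattern above, not dumped from a real checkpoint:

ckpt_keys = [
    "double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight",
    "double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias",
    "ip_adapter_proj_model.proj.weight",
]
prefix = "double_blocks."
blocks, proj = {}, {}
for key in ckpt_keys:
    if key.startswith(prefix):
        # "double_blocks.0.processor.xxx" -> "0.xxx", which is exactly how a
        # torch.nn.ModuleList of 19 IPProcessors names its parameters.
        blocks[key[len(prefix):].replace('.processor.', '.')] = ...
    if key.startswith("ip_adapter_proj_model"):
        proj[key[len("ip_adapter_proj_model."):]] = ...
print(list(blocks))  # ['0.ip_adapter_double_stream_k_proj.weight', '0.ip_adapter_double_stream_v_proj.bias']
print(list(proj))    # ['proj.weight']

So the checkpoint stores one set of IPProcessor weights per double block (19 of them, matching Flux's 19 double blocks) plus the ImageProjModel weights under their own prefix.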
class ApplyFluxIPAdapter:
    @classmethod
    def INPUT_TYPES(s):
        ……
    def applymodel(self, model, ip_adapter_flux, image, strength_model):
        debug = False
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        is_patched = is_model_pathched(model.model)
        print(f"Is model already patched? {is_patched}")
        mul = 1
        if is_patched:
            pbar = ProgressBar(5)
        else:
            mul = 3
            count = len(model.model.diffusion_model.double_blocks)
            pbar = ProgressBar(5*mul+count)
        bi = model.clone()
        tyanochky = bi.model
        clip = ip_adapter_flux['clip_vision']
        if isinstance(clip, FluxClipViT):
            #torch.Size([1, 526, 526, 3])
            #image = torch.permute(image, (0, ))
            clip_device = next(clip.model.parameters()).device
            image = torch.clip(image*255, 0.0, 255)
            out = clip(image).to(dtype=torch.bfloat16)
            neg_out = clip(torch.zeros_like(image)).to(dtype=torch.bfloat16)
        else:
            print("Using old vit clip")
            clip_device = next(clip.parameters()).device
            pixel_values = clip_preprocess(image.to(clip_device)).float()
            out = clip(pixel_values=pixel_values)
            neg_out = clip(pixel_values=torch.zeros_like(pixel_values))
            neg_out = neg_out[2].to(dtype=torch.bfloat16)
            out = out[2].to(dtype=torch.bfloat16)
        pbar.update(mul)
        if not is_patched:
            print("We are patching diffusion model, be patient please")
            patches = FluxUpdateModules(tyanochky, pbar)
            print("Patched succesfully!")
        else:
            print("Model already updated")
        pbar.update(mul)
        #TYANOCHKYBY=16
        ip_projes_dev = next(ip_adapter_flux['ip_adapter_proj_model'].parameters()).device
        ip_adapter_flux['ip_adapter_proj_model'].to(dtype=torch.bfloat16)
        ip_projes = ip_adapter_flux['ip_adapter_proj_model'](out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)
        ip_neg_pr = ip_adapter_flux['ip_adapter_proj_model'](neg_out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)
        ipad_blocks = []
        for block in ip_adapter_flux['double_blocks']:
            ipad = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, strength_model)
            ipad.load_state_dict(block.state_dict())
            ipad.in_hidden_states_neg = ip_neg_pr
            ipad.in_hidden_states_pos = ip_projes
            ipad.to(dtype=torch.bfloat16)
            npp = DoubleStreamMixerProcessor()
            npp.add_ipadapter(ipad)
            ipad_blocks.append(npp)
        pbar.update(mul)
        i = 0
        for name, _ in attn_processors(tyanochky.diffusion_model).items():
            attribute = f"diffusion_model.{name}"
            #old = copy.copy(get_attr(bi.model, attribute))
            if attribute in model.object_patches.keys():
                old = copy.copy((model.object_patches[attribute]))
            else:
                old = None
            processor = merge_loras(old, ipad_blocks[i])
            processor.to(device, dtype=torch.bfloat16)
            bi.add_object_patch(attribute, processor)
            i += 1
        pbar.update(mul)
        return (bi,)
In ApplyFluxIPAdapter, the "patched" business is not obvious at first; we can jump into is_model_pathched later, but the patching process also shows up further down: clone the original model, then cast to the right dtype. After that, the CLIP output goes through the projection twice to get a pos and a neg embedding, much like the text embeddings. Finally, each layer's ip_adapter weights are attached on top of the original layers via object patches. Roughly that is it; we will come back to the patch mechanics once we have read the later code.
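For the clone-and-patch flow, here is a minimal standalone mimic of the idea (my own toy class, not ComfyUI's actual ModelPatcher): the clone carries a dict of attribute-path to replacement-object mappings, and nothing on the shared model is mutated until the patches are applied.

class TinyPatcher:
    def __init__(self, model, object_patches=None):
        self.model = model                      # shared underlying model
        self.object_patches = dict(object_patches or {})

    def clone(self):
        # Cheap clone: shares the model, copies only the patch dict,
        # so each clone can carry its own set of replacements.
        return TinyPatcher(self.model, self.object_patches)

    def add_object_patch(self, name, obj):
        # e.g. name = "diffusion_model.double_blocks.0.processor"
        self.object_patches[name] = obj

    def patch_model(self):
        # Only now are the attributes actually swapped on the model.
        for name, obj in self.object_patches.items():
            parent = self.model
            *path, leaf = name.split(".")
            for p in path:
                parent = getattr(parent, p)
            setattr(parent, leaf, obj)

This also explains why it is framed as "patching": the original weights stay untouched, and different clones can hold different adapter stacks over the same base model.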
Below there is an advanced version with extra parameters begin_strength, end_strength, smothing_type; around the ipad = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, strength_model) line it adds some extra operations, presumably smoothing the strength (or the computation inside) over the sampling steps. That is pretty far from our goal; it will not be too late to read it when we actually need it.
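If that guess about smoothing is right, the extra parameters probably just schedule the IP strength across sampling steps, something like the following (purely a guess at the behavior, not their code):

def scheduled_strength(step, total_steps, begin_strength, end_strength):
    # Hypothetical linear schedule: interpolate the IP-Adapter strength
    # from begin_strength at the first step to end_strength at the last.
    t = step / max(total_steps - 1, 1)   # normalized progress in [0, 1]
    return begin_strength + (end_strength - begin_strength) * t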
At this point the outer calling code is done; now let's look at exactly which modules Flux's IP-Adapter adds.
class ImageProjModel(torch.nn.Module):
    """Projection Model
    https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28
    """
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
Hm, this ImageProjModel is identical to the one in the paper, unsurprising since it is taken straight from the original repo. The whole structure is one Linear plus one LayerNorm, simpler than even an MLP: it projects the CLIP output embedding into the text-feature space (the cross-attention dim) and normalizes it once; how many tokens the projection yields is controlled by clip_extra_context_tokens.
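A quick shape check with the same arguments LoadFluxIPAdapter uses, ImageProjModel(4096, 768, 4): a 768-dim CLIP image embedding becomes 4 tokens in Flux's 4096-dim context space (assuming the class above is in scope):

import torch

improj = ImageProjModel(cross_attention_dim=4096, clip_embeddings_dim=768,
                        clip_extra_context_tokens=4)
image_embeds = torch.randn(2, 768)   # a batch of 2 CLIP image_embeds
tokens = improj(image_embeds)
print(tokens.shape)                  # torch.Size([2, 4, 4096])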
class IPProcessor(nn.Module):
    def __init__(self, context_dim, hidden_dim, ip_hidden_states=None, ip_scale=None):
        super().__init__()
        self.ip_hidden_states = ip_hidden_states
        self.ip_scale = ip_scale
        self.in_hidden_states_neg = None
        self.in_hidden_states_pos = ip_hidden_states
        # Ensure context_dim matches the dimension of ip_hidden_states
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim
        # Initialize projections for IP-adapter
        self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
        self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)
        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)
        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)

    def forward(self, img_q, attn):
        #img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        # IP-adapter processing
        ip_query = img_q  # latent sample query
        ip_key = self.ip_adapter_double_stream_k_proj(self.ip_hidden_states)
        ip_value = self.ip_adapter_double_stream_v_proj(self.ip_hidden_states)
        # Reshape projections for multi-head attention
        ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads)
        ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads)
        # Compute attention between IP projections and the latent query
        ip_attention = F.scaled_dot_product_attention(
            ip_query,
            ip_key,
            ip_value,
            dropout_p=0.0,
            is_causal=False
        )
        ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads)
        return ip_attention*self.ip_scale
A very traditional ip_processor: only the new k/v projections are trained. Note the zero initialization here: rather than inheriting the base layer's k/v weights, zero init makes the IP branch output exactly zero at the start of training, so the patched model initially behaves like the unmodified base model (the same trick as ControlNet's zero convolutions). The forward is also just like an ordinary attn_processor's forward, except the text embedding is swapped for the image embedding.
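A shape walkthrough under Flux's real dims (hidden size 3072, 24 heads of 128; the 4-token, 4096-dim IP states come from the ImageProjModel above). The attn argument is a stand-in object I made up, not Flux's actual attention module:

import torch

class AttnStub:                      # stand-in for the `attn` module forward expects
    num_heads, head_dim = 24, 128

ip_tokens = torch.randn(1, 4, 4096)  # output of ImageProjModel
proc = IPProcessor(context_dim=4096, hidden_dim=24 * 128,
                   ip_hidden_states=ip_tokens, ip_scale=1.0)
img_q = torch.randn(1, 24, 256, 128) # image-stream query: [B, H, L, D]
extra = proc(img_q, AttnStub())      # 256 queries attend over just 4 IP tokens
print(extra.shape)                   # torch.Size([1, 256, 3072])
# (freshly zero-initialized, so values are all zero here; in real use the
# trained k/v weights are loaded first)

The DoubleStreamMixerProcessor then presumably adds this [B, L, 3072] tensor, already scaled by ip_scale, onto the image-stream attention output.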
Flux also has a DoubleStreamMixerProcessor that computes the dual-stream attention in the MMDiT blocks; its structure is quite far from opensora's, and the code is long, so I won't paste it here.
class FluxClipViT:
    def __init__(self, path_model=None):
        if path_model is None:
            self.model = CLIPVisionModelWithProjection.from_pretrained(
                "openai/clip-vit-large-patch14"
            )
        else:
            _dir = os.path.dirname(path_model)
            write_config(_dir)
            config = CLIPVisionConfig.from_pretrained(
                os.path.join(_dir, "flux_clip_config.json")
            )
            self.model = CLIPVisionModelWithProjection.from_pretrained(
                path_model,
                config=config,
                use_safetensors=True,
            )
        self.image_processor = CLIPImageProcessor()
        self.load_device = next(self.model.parameters()).device

    def __call__(self, image):
        img = self.image_processor(
            images=image, return_tensors="pt"
        )
        img = img.pixel_values
        return self.model(img).image_embeds
FluxClipViT is just a thin wrapper around CLIPVisionModelWithProjection; as you can see in the apply code above, the only real difference is that it does less post-processing on the output, returning image_embeds directly.
def FluxUpdateModules(flux_model, pbar=None):
    save_list = {}
    #print((flux_model.diffusion_model.double_blocks))
    #for k,v in flux_model.diffusion_model.double_blocks:
    #if "double" in k:
    count = len(flux_model.diffusion_model.double_blocks)
    patches = {}
    for i in range(count):
        if pbar is not None:
            pbar.update(1)
        patches[f"double_blocks.{i}"] = CopyDSB(flux_model.diffusion_model.double_blocks[i])
        flux_model.diffusion_model.double_blocks[i] = CopyDSB(flux_model.diffusion_model.double_blocks[i])
    return patches
That is the patching part: a simple function that wraps each original double block (via CopyDSB) so the later apply operations have something to hang their replacements on. I am honestly not sure why they lean on the "patch" framing for this.
That is a quick pass over basically all the new modules. There is not much extra; most of it is small changes to fit the Flux framework.

So how should this be carried over to opensora?

The ip_adapter part can stay basically untouched. The ImageProjModel part still needs testing: the query resampler (presumably the perceiver-style Resampler from IP-Adapter Plus) looks more complex and harder to train, while the Linear + LayerNorm here is simple, though I do not know how well it performs. A rough sketch of the resampler alternative follows below.
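For comparison, here is a minimal sketch of that perceiver-style query resampler; all dims are illustrative defaults, not opensora's actual config:

import torch
import torch.nn as nn

class TinyResampler(nn.Module):
    """Learnable query tokens cross-attend to the full CLIP patch sequence,
    instead of projecting a single pooled embedding like ImageProjModel."""
    def __init__(self, dim=1024, out_dim=4096, num_queries=16, heads=8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(1, num_queries, dim) * 0.02)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.proj = nn.Linear(dim, out_dim)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, clip_patch_tokens):            # [B, num_patches, dim]
        q = self.queries.expand(clip_patch_tokens.shape[0], -1, -1)
        out, _ = self.attn(q, clip_patch_tokens, clip_patch_tokens)
        return self.norm(self.proj(out))              # [B, num_queries, out_dim]

More capacity (it sees all patch tokens, not just the pooled embedding), but also more parameters to train than the single Linear + LayerNorm.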
For now, let's wait and see whether the current i2v_adapter can fit the small-sample dataset. If it can, the focus of the work moves to ip_adapter, though we still need a criterion for what counts as trained well.