I happen to need to train an IP-Adapter for a DiT, and Flux, as a DiT architecture, is the first to have a successfully trained IP-Adapter, so it makes a handy reference. There is no training script, but from the inference script we can make a rough guess at which modules and which layers they trained.

class LoadFluxIPAdapter:
    @classmethod
    def INPUT_TYPES(s):
        ……

    def loadmodel(self, ipadatper, clip_vision, provider):
        pbar = ProgressBar(6)
        device=mm.get_torch_device()
        offload_device=mm.unet_offload_device()
        pbar.update(1)
        ret_ipa = {}
        path = os.path.join(dir_xlabs_ipadapters, ipadatper)
        ckpt = load_safetensors(path)
        pbar.update(1)
        path_clip = folder_paths.get_full_path("clip_vision", clip_vision)
        
        # prefer the FluxClipViT wrapper; fall back to ComfyUI's stock clip_vision loader
        try:
            clip = FluxClipViT(path_clip)
        except Exception:
            clip = load_clip_vision(path_clip).model
        
        ret_ipa["clip_vision"] = clip
        prefix = "double_blocks."
        blocks = {}
        proj = {}
        for key, value in ckpt.items():
            if key.startswith(prefix):
                blocks[key[len(prefix):].replace('.processor.', '.')] = value
            if key.startswith("ip_adapter_proj_model"):
                proj[key[len("ip_adapter_proj_model."):]] = value
        pbar.update(1)
        improj = ImageProjModel(4096, 768, 4)  # context dim 4096, CLIP embed dim 768, 4 image tokens
        improj.load_state_dict(proj)
        pbar.update(1)
        ret_ipa["ip_adapter_proj_model"] = improj

        ret_ipa["double_blocks"] = torch.nn.ModuleList([IPProcessor(4096, 3072) for i in range(19)])
        ret_ipa["double_blocks"].load_state_dict(blocks)
        pbar.update(1)
        return (ret_ipa,)

First, look at what gets loaded: an ip_adapter model and a CLIP image model. The ip_adapter checkpoint contains both the weights that get embedded into the network layers and a separately trained ImageProjModel, which we will look at in a moment. The top half is straightforward to read; further down there is double_blocks, which uses IPProcessor. Given the Flux architecture this should target the parallel attention layers inside, and IPProcessor is presumably where the attention computation happens; we will explain it once we reach the actual code.
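To make the key remapping concrete, here is a minimal sketch. The key name is illustrative (inferred from the prefix-stripping logic, not dumped from a real checkpoint), but the transformation is exactly what the loop does:

# hypothetical checkpoint key, shaped after the prefix handling in loadmodel
key = "double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"

prefix = "double_blocks."
new_key = key[len(prefix):].replace('.processor.', '.')
print(new_key)  # "0.ip_adapter_double_stream_k_proj.weight"

# "0.<param>" is exactly the key format that loading a state_dict into
# torch.nn.ModuleList([IPProcessor(...) for _ in range(19)]) expects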

class ApplyFluxIPAdapter:
    @classmethod
    def INPUT_TYPES(s):
        ……

    def applymodel(self, model, ip_adapter_flux, image, strength_model):
        debug=False
        device=mm.get_torch_device()
        offload_device=mm.unet_offload_device()
        is_patched = is_model_pathched(model.model)

        print(f"Is model already patched? {is_patched}")
        mul = 1
        if is_patched:
            pbar = ProgressBar(5)
        else:
            mul = 3
            count = len(model.model.diffusion_model.double_blocks)
            pbar = ProgressBar(5*mul+count)

        bi = model.clone()
        tyanochky = bi.model

        clip = ip_adapter_flux['clip_vision']
        
        if isinstance(clip, FluxClipViT):
            # new wrapper path: takes 0..255 images and returns image_embeds directly
            clip_device = next(clip.model.parameters()).device
            image = torch.clip(image*255, 0.0, 255)
            out = clip(image).to(dtype=torch.bfloat16)
            # a zero image stands in for the negative / unconditional embedding
            neg_out = clip(torch.zeros_like(image)).to(dtype=torch.bfloat16)
        else:
            print("Using old vit clip")
            clip_device = next(clip.parameters()).device
            pixel_values = clip_preprocess(image.to(clip_device)).float()
            out = clip(pixel_values=pixel_values)
            neg_out = clip(pixel_values=torch.zeros_like(pixel_values))
            neg_out = neg_out[2].to(dtype=torch.bfloat16)
            out = out[2].to(dtype=torch.bfloat16)
        
        pbar.update(mul)
        if not is_patched:
            print("We are patching diffusion model, be patient please")
            patches=FluxUpdateModules(tyanochky, pbar)
            print("Patched succesfully!")
        else:
            print("Model already updated")
        pbar.update(mul)

        # project both the positive and the negative CLIP embeddings into extra context tokens
        ip_projes_dev = next(ip_adapter_flux['ip_adapter_proj_model'].parameters()).device
        ip_adapter_flux['ip_adapter_proj_model'].to(dtype=torch.bfloat16)
        ip_projes = ip_adapter_flux['ip_adapter_proj_model'](out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)
        ip_neg_pr = ip_adapter_flux['ip_adapter_proj_model'](neg_out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16)

        # rebuild each block's IPProcessor with the projected image tokens and
        # strength attached, then wrap it in a DoubleStreamMixerProcessor
        ipad_blocks = []
        for block in ip_adapter_flux['double_blocks']:
            ipad = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, strength_model)
            ipad.load_state_dict(block.state_dict())
            ipad.in_hidden_states_neg = ip_neg_pr
            ipad.in_hidden_states_pos = ip_projes
            ipad.to(dtype=torch.bfloat16)
            npp = DoubleStreamMixerProcessor()
            npp.add_ipadapter(ipad)
            ipad_blocks.append(npp)
        pbar.update(mul)
        # install each mixer as an object patch on the clone, merging with any
        # processor an earlier patch (e.g. a LoRA) already put on that attention
        i=0
        for name, _ in attn_processors(tyanochky.diffusion_model).items():
            attribute = f"diffusion_model.{name}"
            if attribute in model.object_patches.keys():
                old = copy.copy((model.object_patches[attribute]))
            else:
                old = None
            processor = merge_loras(old, ipad_blocks[i])
            processor.to(device, dtype=torch.bfloat16)
            bi.add_object_patch(attribute, processor)
            i+=1
        pbar.update(mul)
        return (bi,)

In ApplyFluxIPAdapter, the patched flag is not obvious at first; we will step into is_model_pathched (the upstream spelling) in a moment, though the actual patching also happens further down: clone the original model, then cast dtypes. After that, the CLIP output is projected, once for pos and once for neg, much like text embeddings. Then the ip_adapter weights for each layer are attached on top of the original layers. Roughly that; we will come back to the patch part after reading the later code.
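Why keep both a pos and a neg projection? The zero-image embedding plays the role of an unconditional branch, analogous to the empty prompt in CFG. The code shown here does not reveal how the two are blended per step, but one plausible use (my assumption, not confirmed by the source) is a simple interpolation by strength:

import torch

def blend_ip_states(pos: torch.Tensor, neg: torch.Tensor, strength: float) -> torch.Tensor:
    """Hypothetical blend: at strength 0 the adapter effectively sees the
    zero-image embedding (no style injected), at strength 1 the full reference."""
    return neg + strength * (pos - neg)

pos = torch.randn(1, 4, 4096)   # projected reference-image tokens
neg = torch.randn(1, 4, 4096)   # projection of a zero image (unconditional branch)
ip_states = blend_ip_states(pos, neg, 0.7)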

Below there is an advanced version with extra parameters: begin_strength, end_strength, smothing_type. Around ipad = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, strength_model) it performs a few more operations, presumably smoothing the strength applied to the weights or to the computation inside. That is too far ahead for now; it is fine to look at it once we actually need it.
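One plausible reading of begin_strength / end_strength / smothing_type (an assumption on my part; the advanced node's code is not shown here) is a per-step strength schedule, interpolated between the two endpoints with different easing curves:

import math

def ip_strength(t: float, begin: float, end: float, smoothing: str = "linear") -> float:
    """Hypothetical schedule: t in [0, 1] is the fraction of sampling steps done."""
    if smoothing == "linear":
        w = t
    elif smoothing == "cosine":
        w = 0.5 - 0.5 * math.cos(math.pi * t)  # smooth ramp from 0 to 1
    else:
        raise ValueError(f"unknown smoothing type: {smoothing}")
    return begin + (end - begin) * w

# e.g. stronger image guidance early in sampling, weaker late
print([round(ip_strength(t / 9, 0.9, 0.3, "cosine"), 3) for t in range(10)])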

At this point the outer calling code is done; let's see exactly which modules the Flux IP-Adapter adds.

class ImageProjModel(torch.nn.Module):
    """Projection Model
    https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

Hm, this ImageProjModel is exactly the one from the paper, taken straight from the original repo. The whole thing is a linear layer plus a LayerNorm, simpler than an MLP. In short, it maps the tokens coming out of CLIP into the text-feature space (the cross-attention dim) and normalizes once for alignment; how many tokens come out of the mapping is set by clip_extra_context_tokens.
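With the arguments used in loadmodel above, ImageProjModel(4096, 768, 4), the shapes work out as follows (a quick sanity check with random inputs, assuming the ImageProjModel class defined above is in scope):

import torch

proj = ImageProjModel(cross_attention_dim=4096, clip_embeddings_dim=768,
                      clip_extra_context_tokens=4)

image_embeds = torch.randn(2, 768)  # CLIP ViT-L/14 image_embeds, batch of 2
tokens = proj(image_embeds)
print(tokens.shape)                 # torch.Size([2, 4, 4096])
# linear: [2, 768] -> [2, 4*4096]; reshape: -> [2, 4, 4096]; LayerNorm over 4096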


class IPProcessor(nn.Module):
    def __init__(self, context_dim, hidden_dim, ip_hidden_states=None, ip_scale=None):
        super().__init__()
        self.ip_hidden_states = ip_hidden_states
        self.ip_scale = ip_scale
        self.in_hidden_states_neg = None
        self.in_hidden_states_pos = ip_hidden_states
        # Ensure context_dim matches the dimension of ip_hidden_states
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        # Initialize projections for IP-adapter
        self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
        self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)
        
        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)
        
        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
        nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)

    def forward(self, img_q, attn):
        # IP-adapter processing: the latent image query attends to the image tokens
        ip_query = img_q  # already shaped (B, H, L, D) by the surrounding attention
        ip_key = self.ip_adapter_double_stream_k_proj(self.ip_hidden_states)
        ip_value = self.ip_adapter_double_stream_v_proj(self.ip_hidden_states)

        # Reshape projections for multi-head attention
        ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads)
        ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads)

        # Compute attention between IP projections and the latent query
        ip_attention = F.scaled_dot_product_attention(
            ip_query,
            ip_key,
            ip_value,
            dropout_p=0.0,
            is_causal=False
        )
        ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads)
        return ip_attention*self.ip_scale

A very conventional ip_processor: only the k and v projections are trained. Note, though, that they are zero-initialized here; perhaps the training script first has them inherit the corresponding layer's k/v weights. A side effect of zero init is that the freshly patched model reproduces the base model exactly, since the adapter branch outputs zero, as the sketch below verifies. The forward is the same story: it is just like an ordinary attn_processor's forward, except the text embedding is swapped for the image embedding.
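A quick check of what zero init buys: with zero k/v projections, the IP branch returns exactly zero, so training starts from an identity mapping. A minimal standalone sketch (Flux-like sizes assumed: 24 heads times head dim 128 = hidden 3072):

import torch
import torch.nn.functional as F
from einops import rearrange

B, H, D, Lq, Lip = 1, 24, 128, 16, 4
k_proj = torch.nn.Linear(4096, H * D)
v_proj = torch.nn.Linear(4096, H * D)
torch.nn.init.zeros_(k_proj.weight); torch.nn.init.zeros_(k_proj.bias)
torch.nn.init.zeros_(v_proj.weight); torch.nn.init.zeros_(v_proj.bias)

img_q = torch.randn(B, H, Lq, D)       # latent query from the double block
ip_tokens = torch.randn(B, Lip, 4096)  # projected image tokens

ip_k = rearrange(k_proj(ip_tokens), 'B L (H D) -> B H L D', H=H)
ip_v = rearrange(v_proj(ip_tokens), 'B L (H D) -> B H L D', H=H)
out = F.scaled_dot_product_attention(img_q, ip_k, ip_v)
print(out.abs().max())  # tensor(0.) — softmax over zero logits is uniform, but V is all zeros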

Flux also has a DoubleStreamMixerProcessor, used to compute the double-stream attention in the MMDiT. Its structure is quite far from OpenSora's, and the code is long, so I won't paste it here.
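From the pieces we have seen, the essential job of the mixer can be sketched as: run the block's normal double-stream attention, then add every registered IP branch's output (each already scaled by ip_scale inside IPProcessor) to the image stream. This is a rough illustration of the idea, not the actual XLabs class; the base-processor signature and its returning of img_q are my assumptions:

import torch.nn as nn

class DoubleStreamMixerSketch(nn.Module):
    """Illustrative only: the real DoubleStreamMixerProcessor also merges LoRAs
    and handles pos/neg states; here we show just the residual IP injection."""
    def __init__(self, base_processor):
        super().__init__()
        self.base = base_processor
        self.ip_adapters = nn.ModuleList()

    def add_ipadapter(self, ipad):
        self.ip_adapters.append(ipad)

    def forward(self, attn, img, txt, **kwargs):
        # normal MMDiT double-stream attention over image + text streams
        # (assumed here to also return the image query for reuse)
        img_out, txt_out, img_q = self.base(attn, img, txt, **kwargs)
        # each IP branch attends from the image query to the image tokens
        for ipad in self.ip_adapters:
            img_out = img_out + ipad(img_q, attn)  # ip_scale applied inside
        return img_out, txt_out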

class FluxClipViT:
    def __init__(self, path_model = None):
        if path_model is None:
            self.model = CLIPVisionModelWithProjection.from_pretrained(
                "openai/clip-vit-large-patch14"
            )
            
        else:
            _dir = os.path.dirname(path_model)
            write_config(_dir)
            config = CLIPVisionConfig.from_pretrained(
                os.path.join(_dir, "flux_clip_config.json")
            )
            self.model = CLIPVisionModelWithProjection.from_pretrained(
                path_model,
                config=config,
                use_safetensors = True,
            )
        self.image_processor = CLIPImageProcessor()
        self.load_device = next(self.model.parameters()).device

    def __call__(self, image):
        # HF CLIPImageProcessor handles resize/normalize; note pixel_values are
        # not moved to the model device here, the caller must keep them aligned
        img = self.image_processor(
            images=image, return_tensors="pt"
            )
        img = img.pixel_values
        return self.model(img).image_embeds

FluxClipViT is just a light wrapper around CLIPVisionModelWithProjection; as you can also see in the apply code above, it merely saves a few steps on the output side, returning image_embeds directly (which for clip-vit-large-patch14 are 768-dimensional, matching the clip_embeddings_dim=768 passed to ImageProjModel).

def FluxUpdateModules(flux_model, pbar=None):
    # wrap every double-stream block with CopyDSB and return the wrapped
    # copies as a patch dict keyed by "double_blocks.{i}"
    count = len(flux_model.diffusion_model.double_blocks)
    patches = {}

    for i in range(count):
        if pbar is not None:
            pbar.update(1)
        patches[f"double_blocks.{i}"]=CopyDSB(flux_model.diffusion_model.double_blocks[i])
        flux_model.diffusion_model.double_blocks[i]=CopyDSB(flux_model.diffusion_model.double_blocks[i])
    return patches

That is the patch-related part: just a simple routine that clones the original layers so the later apply step has something to attach to. I am honestly not sure why they lean on the patch mechanism; my guess is that ComfyUI applies object patches around sampling, so the base model object is never mutated and can be shared across workflows (see the toy sketch below).
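For readers unfamiliar with ComfyUI's patch mechanism, here is my reading of the pattern as a toy (this is not ComfyUI's real ModelPatcher, just an illustration of the calls we saw in applymodel):

# toy illustration of the object-patch idea (not ComfyUI's actual implementation)
class ToyPatcher:
    def __init__(self, model):
        self.model = model            # shared base model, never mutated here
        self.object_patches = {}

    def clone(self):
        p = ToyPatcher(self.model)    # weights shared, patches copied
        p.object_patches = dict(self.object_patches)
        return p

    def add_object_patch(self, attribute, obj):
        self.object_patches[attribute] = obj  # applied around sampling, reverted after

# this mirrors the calls above: bi = model.clone();
# bi.add_object_patch(f"diffusion_model.{name}", processor)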

With that, we have quickly gone through pretty much all of the new modules. There is not much extra; most of it is small changes made to fit the Flux framework.

So how should this be adapted for OpenSora?

The ip_adapter part can basically stay untouched. The ImageProjModel part still needs testing: the query resampler (presumably the Resampler from IP-Adapter Plus) looks more complex and harder to train, while the linear+LN here is simple; I just don't know how well it will perform.

First let's see whether the current i2v_adapter can fit the small-sample dataset. If it fits, the focus of the work shifts to the ip_adapter, but we still need to find a criterion for what counts as well-trained.
