Fine-tune the Pixel2Style2Pixel model with my custom data set

In summary: False) def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, inject_latent=None, return_latents=False, alpha=None): if input_code: codes = x else: codes = self.encoder(x) if latent_mask is not None: for i in latent_mask: if inject_latent is not None: codes[:,i] = inject_latent[:, i] else: codes[:, i] = 0 input_is_latent = not input_code images, result_latent = self.decoder([codes
  • #1
btb4198
572
10
I want to fine-tune the Pixel2Style2Pixel model with my custom data set, but I keep getting an error when I try to load the pretrained weights. Here is my code:

fine-tune the Pixel2Style2Pixel model:
# Load the pre-trained model
# NOTE(review): `os`, `SimpleNamespace`, `MODEL_PATH` and `DEVICE` are not
# defined in this snippet — presumably they come from an earlier notebook cell.
# The chdir presumably makes the repository's top-level `models` package
# importable for the import on the next line — confirm against sys.path.
os.chdir("/content/pixel2style2pixel")
from models.psp import pSp

# Options that are later wrapped in a SimpleNamespace and handed to pSp as `opts`.
config = {
    "lr": 0.0001,
    "betas": (0.9, 0.999),
    "weight_decay": 1e-6,
    "stylegan_size": 1024,
    "checkpoint_path": MODEL_PATH,
    "device": DEVICE,
    "input_nc": 3,  # Assuming 3 input channels
    "output_size": 256,  # Add the missing attribute
    "encoder_type": 'GradualStyleEncoder',  # Add the missing attribute
}

# Overrides the MODEL_PATH value assigned in the literal above.
config["checkpoint_path"] = "/content/pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt"
updated_config = config.copy()
# NOTE(review): pSp.__init__ recomputes opts.n_styles from opts.output_size
# (int(log2(output_size)) * 2 - 2, i.e. 14 for 256), so this value is overwritten.
updated_config['n_styles'] = 16
# NOTE(review): pSp.forward reads opts.start_from_latent_avg (and opts.learn_in_w
# when that is truthy); neither is set here — confirm whether they must be added
# before the model is called.
config_object = SimpleNamespace(**updated_config)
# NOTE(review): second import of pSp under a different module path; it rebinds
# the name imported above. Presumably only one of the two imports is needed.
from pixel2style2pixel.models.psp import pSp

def get_keys(d, name):
    """Select the entries of a (possibly wrapped) state dict that belong to `name`.

    If `d` contains a 'state_dict' entry, that nested mapping is searched
    instead of `d` itself. Every key that starts with `name` is kept, with its
    first ``len(name) + 1`` characters (the prefix plus one separator
    character) removed.

    Args:
        d: Mapping of parameter names to values, or a checkpoint mapping that
            holds such a mapping under 'state_dict'.
        name: Prefix identifying the sub-module whose entries are wanted.

    Returns:
        A new dict of the matching entries, keyed by the shortened names.
    """
    source = d['state_dict'] if 'state_dict' in d else d
    cut = len(name) + 1
    selected = {}
    for key, value in source.items():
        if key.startswith(name):
            selected[key[cut:]] = value
    return selected

class CustompSp(pSp):
    """pSp subclass that overrides checkpoint loading and keeps its own torch device."""

    def __init__(self, opts):
        # NOTE: pSp.__init__ calls self.load_weights() before it returns, and the
        # load_weights() override below reads self.device. Because self.device is
        # only assigned on the line after super().__init__(), that lookup fails
        # with the AttributeError shown in the traceback.
        super().__init__(opts)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_weights(self):
        """Load encoder, decoder and latent-average data from opts.checkpoint_path."""
        ckpt = torch.load(self.opts.checkpoint_path, map_location=self.device)

        # NOTE(review): the selected keys keep their original names here, whereas
        # get_keys() removes the sub-module prefix before loading. With
        # strict=False a name mismatch is not reported — confirm that the
        # encoder weights are actually being applied.
        filtered_state_dict = {k: v for k, v in ckpt['state_dict'].items() if 'encoder' in k}
        self.encoder.load_state_dict(filtered_state_dict, strict=False)

        # Only the substring 'module.' is removed from decoder keys; any other
        # prefix is left in place (same caveat as for the encoder above).
        decoder_state_dict = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items() if 'decoder' in k}
        self.decoder.load_state_dict(decoder_state_dict, strict=False)

        self.load_latent_avg(ckpt)

    def load_latent_avg(self, ckpt):
        """Set self.latent_avg from ckpt['latent_avg'] (moved to self.device), or None if absent."""
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.device)
        else:
            self.latent_avg = None

# Constructing the model runs pSp.__init__, which calls load_weights() right away.
model = CustompSp(config_object)
# Put the module in training mode and move it to DEVICE.
# NOTE(review): DEVICE is not defined in this snippet — presumably set earlier.
model.train()
model.to(DEVICE)
I am getting this error: AttributeError: 'CustompSp' object has no attribute 'device'
Full error message:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-71-36ef46aa58b7> in <cell line: 1>()
----> 1 model = CustompSp(config_object)
      2 model.train()
      3 model.to(DEVICE)

3 frames
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1612             if name in modules:
   1613                 return modules[name]
-> 1614         raise AttributeError("'{}' object has no attribute '{}'".format(
   1615             type(self).__name__, name))
   1616

AttributeError: 'CustompSp' object has no attribute 'device'
 
Technology news on Phys.org
  • #2
Your CustompSp class inherits from the pSp class. Does the pSp class (which isn't shown in the code you provided) have a device attribute?
 
  • #3
Sorry, I am posting it now:
psp.py:
"""
This file defines the core research contribution
"""
import matplotlib
matplotlib.use('Agg')
import math

import torch
from torch import nn
from models.encoders import psp_encoders
from models.stylegan2.model import Generator
from configs.paths_config import model_paths


def get_keys(d, name):
    """Return the entries of a state dict whose keys start with `name`.

    If `d` holds a nested 'state_dict' entry, that mapping is used instead.
    The first ``len(name) + 1`` characters (prefix plus one separator
    character) are removed from every kept key.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
    return d_filt


class pSp(nn.Module):
    """Encoder + StyleGAN2 generator wrapper.

    Construction builds the encoder selected by ``opts.encoder_type`` and a
    ``Generator`` decoder, then immediately loads weights via
    ``load_weights()``. Subclasses that override ``load_weights()`` must
    therefore have everything that method needs in place *before* calling
    ``super().__init__(opts)``.
    """

    def __init__(self, opts):
        """Store `opts`, build the sub-modules and load their weights.

        Note that ``opts.n_styles`` is overwritten here from
        ``opts.output_size``; any value supplied by the caller is discarded.
        """
        super(pSp, self).__init__()
        self.set_opts(opts)
        # compute number of style inputs based on the output resolution
        self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.output_size, 512, 8)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    def set_encoder(self):
        """Build and return the encoder named by ``opts.encoder_type``.

        Raises:
            Exception: if ``opts.encoder_type`` is not one of the three
                supported encoder names.
        """
        if self.opts.encoder_type == 'GradualStyleEncoder':
            encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
        else:
            raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
        return encoder

    def load_weights(self):
        """Load encoder/decoder weights and the latent average.

        With ``opts.checkpoint_path`` set, everything comes from that single
        checkpoint (strict loading, prefixes removed by ``get_keys``).
        Otherwise the encoder is initialised from ``model_paths['ir_se50']``
        and the decoder from ``opts.stylegan_weights``; that branch also reads
        ``opts.label_nc`` and ``opts.learn_in_w``.
        """
        if self.opts.checkpoint_path is not None:
            print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
            self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            print('Loading encoders weights from irse50!')
            encoder_ckpt = torch.load(model_paths['ir_se50'])
            # if input to encoder is not an RGB image, do not load the input layer weights
            if self.opts.label_nc != 0:
                encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print('Loading decoder weights from pretrained!')
            ckpt = torch.load(self.opts.stylegan_weights)
            self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
            if self.opts.learn_in_w:
                self.__load_latent_avg(ckpt, repeat=1)
            else:
                self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        """Encode `x` (unless it already is a code) and decode it to images.

        Args:
            x: Encoder input, or the codes themselves when `input_code` is True.
            resize: If True, pass the decoded images through ``face_pool``
                (adaptive average pooling to 256x256).
            latent_mask: Optional iterable of indices along dimension 1 of the
                codes to overwrite.
            input_code: If True, skip the encoder and use `x` as the codes.
            randomize_noise: Forwarded to the decoder.
            inject_latent: Optional codes to copy (or blend) into the masked
                positions; when None the masked positions are zeroed.
            return_latents: If True, also return the decoder's latent output.
            alpha: Optional blend weight; masked positions become
                ``alpha * inject_latent + (1 - alpha) * codes``.

        Returns:
            ``images`` or ``(images, result_latent)`` depending on
            `return_latents`.
        """
        if input_code:
            codes = x
        else:
            codes = self.encoder(x)
            # normalize with respect to the center of an average face
            if self.opts.start_from_latent_avg:
                if self.opts.learn_in_w:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        input_is_latent = not input_code
        images, result_latent = self.decoder([codes],
                                             input_is_latent=input_is_latent,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images

    def set_opts(self, opts):
        """Store the options object on the instance."""
        self.opts = opts

    def __load_latent_avg(self, ckpt, repeat=None):
        """Set ``self.latent_avg`` from ``ckpt['latent_avg']`` on ``opts.device``.

        When `repeat` is given the tensor is tiled with ``repeat(repeat, 1)``.
        If the checkpoint has no 'latent_avg' entry, ``latent_avg`` is None.
        """
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None

You are correct that the pSp class does not have a device attribute. So I modify the CustompSp class by adding a device attribute as follows:

updated CustompSp class:
class CustompSp(pSp):
    """pSp subclass with a lenient, device-aware checkpoint loader.

    Unlike ``pSp.load_weights`` this loader uses ``strict=False`` so a
    checkpoint that only partially matches the architecture can still be
    used as a starting point for fine-tuning.
    """

    def __init__(self, opts):
        # pSp.__init__ calls load_weights(), which reads self.device, so the
        # attribute has to exist before the parent constructor runs.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        super().__init__(opts)

    @staticmethod
    def _submodule_state(state_dict, prefix):
        """Return the entries stored under ``prefix + '.'`` with that prefix removed.

        A leading 'module.' (as left behind by a DataParallel wrapper) is
        dropped first, so both 'encoder.x' and 'module.encoder.x' map to 'x'.
        """
        wrapper = 'module.'
        wanted = prefix + '.'
        selected = {}
        for key, value in state_dict.items():
            if key.startswith(wrapper):
                key = key[len(wrapper):]
            if key.startswith(wanted):
                selected[key[len(wanted):]] = value
        return selected

    def load_weights(self):
        """Load encoder, decoder and latent average from ``opts.checkpoint_path``.

        Keys are stripped of their 'encoder.' / 'decoder.' prefix before being
        handed to the sub-modules; without that, ``strict=False`` would
        silently skip every entry and the model would keep its random
        initialisation. Any keys that still fail to match are reported.
        """
        ckpt = torch.load(self.opts.checkpoint_path, map_location=self.device)
        state_dict = ckpt['state_dict']

        for prefix, module in (('encoder', self.encoder), ('decoder', self.decoder)):
            result = module.load_state_dict(self._submodule_state(state_dict, prefix), strict=False)
            if result.missing_keys or result.unexpected_keys:
                print('CustompSp: {} loaded with {} missing and {} unexpected keys'.format(
                    prefix, len(result.missing_keys), len(result.unexpected_keys)))

        self.load_latent_avg(ckpt)

    def load_latent_avg(self, ckpt):
        """Set ``self.latent_avg`` from the checkpoint (on ``self.device``), or None if absent."""
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.device)
        else:
            self.latent_avg = None
Anyhow it seemed to work
 
  • #4
btb4198 said:
You are correct that the pSp class does not have a device attribute.
I didn't say that the pSp class was missing a device attribute. I just asked whether the pSp class had such an attribute. If you get an error that such-and-such attribute is missing, an obvious thing to do is to see if that attribute is present in the class or any super class.
btb4198 said:
So I modify the CustompSp class by adding a device attribute as follows:

btb4198 said:
Anyhow it seemed to work
What a surprise...
 

Similar threads

  • Programming and Computer Science
Replies
5
Views
3K
  • Programming and Computer Science
Replies
3
Views
1K
  • Programming and Computer Science
Replies
1
Views
1K
  • Programming and Computer Science
Replies
2
Views
788
  • Programming and Computer Science
Replies
1
Views
2K
  • Engineering and Comp Sci Homework Help
Replies
3
Views
470
Back
Top