FineTune the Pixel2Style2Pixel model with my custom data set

In summary: A user subclasses pSp in order to fine-tune Pixel2Style2Pixel on a custom data set and gets "AttributeError: 'CustompSp' object has no attribute 'device'". The subclass assigns self.device only after super().__init__() returns, but pSp.__init__ already calls load_weights(), which reads self.device. Assigning self.device before calling super().__init__() makes the error go away.
  • #1
btb4198
572
10
I want to fine-tune the Pixel2Style2Pixel model with my custom data set, but I keep getting an error when I try to load the pre-trained weights. Here is my code:

fine-tune the Pixel2Style2Pixel model:
# Load the pre-trained model
# NOTE(review): this snippet relies on names defined elsewhere in the notebook
# (os, SimpleNamespace, MODEL_PATH, DEVICE) -- confirm they exist before running.
# Presumably the chdir is what makes the repo's top-level `models` package
# importable -- TODO confirm that the working directory is on sys.path.
os.chdir("/content/pixel2style2pixel")
from models.psp import pSp

# Options handed to pSp. pSp reads them as attributes (self.opts.<name>), which
# is why the dict is converted to a SimpleNamespace further down.
config = {
    "lr": 0.0001,
    "betas": (0.9, 0.999),
    "weight_decay": 1e-6,
    "stylegan_size": 1024,
    "checkpoint_path": MODEL_PATH,  # overwritten a few lines below
    "device": DEVICE,
    "input_nc": 3,  # Assuming 3 input channels
    "output_size": 256,  # Add the missing attribute
    "encoder_type": 'GradualStyleEncoder',  # Add the missing attribute
}
# NOTE(review): pSp.forward (see the psp.py listing in this thread) also reads
# opts.start_from_latent_avg, and opts.learn_in_w when that flag is true.
# Neither is set here, so a forward pass on images would raise AttributeError.

# Replaces the MODEL_PATH value set in the dict literal above.
config["checkpoint_path"] = "/content/pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt"
updated_config = config.copy()
# NOTE(review): pSp.__init__ recomputes opts.n_styles from output_size
# (int(log2(256)) * 2 - 2 == 14), so this value is overwritten on construction.
updated_config['n_styles'] = 16
config_object = SimpleNamespace(**updated_config)
# NOTE(review): rebinds pSp from a different package path than the import
# above; likely redundant -- confirm which of the two imports is intended.
from pixel2style2pixel.models.psp import pSp

def get_keys(d, name):
    """Return the entries of a checkpoint that belong to one sub-module.

    Args:
        d: A state dict, or a checkpoint dict holding one under 'state_dict'.
        name: Sub-module prefix to select, e.g. 'encoder'.

    Returns:
        A new dict with the entries whose key starts with ``name + '.'``,
        keyed by the remainder of the original key.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on the dotted prefix: matching on the bare name would also select
    # keys such as 'encoder2.w' and then strip one character too many.
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}

class CustompSp(pSp):
    """pSp subclass that overrides how checkpoint weights are loaded.

    NOTE(review): as written, construction fails with
    "AttributeError: 'CustompSp' object has no attribute 'device'" -- see the
    comment in __init__.
    """
    def __init__(self, opts):
        # pSp.__init__ (see the psp.py listing in this thread) ends by calling
        # self.load_weights(), which resolves to the override below and reads
        # self.device. The attribute is only assigned on the line after
        # super().__init__() returns, so it does not exist yet at that point.
        super().__init__(opts)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_weights(self):
        """Load encoder/decoder weights and latent_avg from opts.checkpoint_path."""
        ckpt = torch.load(self.opts.checkpoint_path, map_location=self.device)

        # NOTE(review): the selected keys keep their full names, whereas
        # pSp.load_weights strips the 'encoder.' / 'decoder.' prefix via
        # get_keys(). With strict=False, non-matching keys are ignored without
        # an error -- confirm that any weights are actually loaded.
        filtered_state_dict = {k: v for k, v in ckpt['state_dict'].items() if 'encoder' in k}
        self.encoder.load_state_dict(filtered_state_dict, strict=False)

        decoder_state_dict = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items() if 'decoder' in k}
        self.decoder.load_state_dict(decoder_state_dict, strict=False)

        self.load_latent_avg(ckpt)

    def load_latent_avg(self, ckpt):
        """Set self.latent_avg from ckpt['latent_avg'] (moved to self.device), or None if absent."""
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.device)
        else:
            self.latent_avg = None

# Build the model. Construction also loads the checkpoint, because
# pSp.__init__ calls load_weights().
model = CustompSp(config_object)
# Put the module in training mode for fine-tuning.
model.train()
# Move the module to the target device.
model.to(DEVICE)
I am getting this error: AttributeError: 'CustompSp' object has no attribute 'device'
Full error message:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-71-36ef46aa58b7> in <cell line: 1>()
----> 1 model = CustompSp(config_object)
      2 model.train()
      3 model.to(DEVICE)

3 frames
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1612             if name in modules:
   1613                 return modules[name]
-> 1614         raise AttributeError("'{}' object has no attribute '{}'".format(
   1615             type(self).__name__, name))
   1616

AttributeError: 'CustompSp' object has no attribute 'device'
 
Technology news on Phys.org
  • #2
Your CustompSp class inherits from the pSp class. Does the pSp class (which isn't shown in the code you provided) have a device attribute?
 
  • #3
Sorry, I am posting it now:
psp.py:
"""
This file defines the core research contribution
"""
import matplotlib
matplotlib.use('Agg')
import math

import torch
from torch import nn
from models.encoders import psp_encoders
from models.stylegan2.model import Generator
from configs.paths_config import model_paths


def get_keys(d, name):
    """Return the entries of a checkpoint that belong to one sub-module.

    Args:
        d: A state dict, or a checkpoint dict holding one under 'state_dict'.
        name: Sub-module prefix to select, e.g. 'encoder'.

    Returns:
        A new dict with the entries whose key starts with ``name + '.'``,
        keyed by the remainder of the original key.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on the dotted prefix so that a key such as 'encoder2.w' is not
    # selected (and mangled) when name == 'encoder'.
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}


class pSp(nn.Module):
    """Encoder plus StyleGAN2 generator, with weights loaded on construction.

    The options object ``opts`` is read by attribute. Attributes used in this
    class: output_size, encoder_type, checkpoint_path, device,
    start_from_latent_avg, learn_in_w, and, when checkpoint_path is None,
    label_nc and stylegan_weights. ``opts.n_styles`` is written by __init__.
    """

    def __init__(self, opts):
        """Build the encoder and decoder, then load their weights.

        Note for subclasses: load_weights() is called from here, so anything an
        overriding load_weights() needs must exist before super().__init__()
        is called.
        """
        super(pSp, self).__init__()
        self.set_opts(opts)
        # compute number of style inputs based on the output resolution
        self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.output_size, 512, 8)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    def set_encoder(self):
        """Instantiate the encoder named by opts.encoder_type.

        Returns:
            The constructed encoder module.

        Raises:
            Exception: If opts.encoder_type is not one of the supported names.
        """
        if self.opts.encoder_type == 'GradualStyleEncoder':
            encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
        else:
            raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
        return encoder

    def load_weights(self):
        """Load encoder/decoder weights and latent_avg.

        With opts.checkpoint_path set, everything comes from that pSp
        checkpoint and key sets must match exactly (strict=True). Otherwise the
        encoder is initialised from the ir_se50 weights and the decoder from
        opts.stylegan_weights, both non-strictly.
        """
        if self.opts.checkpoint_path is not None:
            print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
            self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            print('Loading encoders weights from irse50!')
            # Load onto CPU, as in the checkpoint branch above, so that a file
            # saved from a GPU can also be read on a machine without one.
            encoder_ckpt = torch.load(model_paths['ir_se50'], map_location='cpu')
            # if input to encoder is not an RGB image, do not load the input layer weights
            if self.opts.label_nc != 0:
                encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print('Loading decoder weights from pretrained!')
            ckpt = torch.load(self.opts.stylegan_weights, map_location='cpu')
            self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
            if self.opts.learn_in_w:
                self.__load_latent_avg(ckpt, repeat=1)
            else:
                self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        """Encode ``x`` (unless it already holds codes) and decode to images.

        Args:
            x: Input passed to the encoder, or the codes themselves when
                ``input_code`` is true.
            resize: If true, pool the decoder output to 256x256.
            latent_mask: Optional iterable of indices along dimension 1 of the
                codes to overwrite.
            input_code: If true, skip the encoder and use ``x`` as the codes.
            randomize_noise: Forwarded to the decoder.
            inject_latent: Optional codes used for the masked indices; when
                None, the masked indices are set to zero.
            return_latents: If true, also return the decoder's latents.
            alpha: Optional blend weight between ``inject_latent`` and the
                current codes at the masked indices.

        Returns:
            ``images``, or ``(images, result_latent)`` when ``return_latents``
            is true.
        """
        if input_code:
            codes = x
        else:
            codes = self.encoder(x)
            # normalize with respect to the center of an average face
            if self.opts.start_from_latent_avg:
                if self.opts.learn_in_w:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            # Written in place: when input_code is true this modifies the
            # caller's tensor.
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        # The decoder is told the codes are latents only when they came from
        # the encoder.
        input_is_latent = not input_code
        images, result_latent = self.decoder([codes],
                                             input_is_latent=input_is_latent,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images

    def set_opts(self, opts):
        """Store the options object on the instance."""
        self.opts = opts

    def __load_latent_avg(self, ckpt, repeat=None):
        """Set self.latent_avg from ``ckpt['latent_avg']``, or None if absent.

        Args:
            ckpt: Loaded checkpoint dict.
            repeat: If given, the tensor is tiled ``repeat`` times along
                dimension 0 after being moved to opts.device.
        """
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None

You are correct that the pSp class does not have a device attribute. So I modify the CustompSp class by adding a device attribute as follows:

updated CustompSp class:
class CustompSp(pSp):
    """pSp subclass that loads checkpoint weights non-strictly for fine-tuning.

    Keys present on only one side (model or checkpoint) are ignored, so a
    checkpoint can be loaded into a model built with different options.
    """

    def __init__(self, opts):
        # Must be assigned before super().__init__(): pSp.__init__ calls
        # load_weights(), which reads self.device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        super().__init__(opts)

    @staticmethod
    def _submodule_state(state_dict, prefix):
        """Return the entries under ``prefix + '.'`` with that prefix removed.

        A leading 'module.' on a key is dropped first, mirroring the
        'module.' removal the decoder filter previously performed.
        """
        dotted = prefix + '.'
        selected = {}
        for key, value in state_dict.items():
            if key.startswith('module.'):
                key = key[len('module.'):]
            if key.startswith(dotted):
                selected[key[len(dotted):]] = value
        return selected

    def load_weights(self):
        """Load encoder and decoder weights and latent_avg from the checkpoint.

        Raises:
            RuntimeError: If the checkpoint has no 'encoder.' or no 'decoder.'
                entries, in which case nothing would have been loaded.
        """
        # torch.load unpickles the file; only load checkpoints from a trusted source.
        ckpt = torch.load(self.opts.checkpoint_path, map_location=self.device)
        state_dict = ckpt['state_dict']

        # The checkpoint stores keys as 'encoder.<param>' / 'decoder.<param>',
        # while each sub-module expects bare '<param>' names. Without removing
        # the prefix, strict=False silently matches nothing and the model
        # keeps its random initialisation.
        encoder_state = self._submodule_state(state_dict, 'encoder')
        decoder_state = self._submodule_state(state_dict, 'decoder')
        if not encoder_state or not decoder_state:
            raise RuntimeError(
                'No encoder/decoder weights found in checkpoint: {}'.format(
                    self.opts.checkpoint_path))

        self.encoder.load_state_dict(encoder_state, strict=False)
        self.decoder.load_state_dict(decoder_state, strict=False)

        self.load_latent_avg(ckpt)

    def load_latent_avg(self, ckpt):
        """Set self.latent_avg from ckpt['latent_avg'] on self.device, or None if absent."""
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.device)
        else:
            self.latent_avg = None
Anyhow it seemed to work
 
  • #4
btb4198 said:
You are correct that the pSp class does not have a device attribute.
I didn't say that the pSp class was missing a device attribute. I just asked whether the pSp class had such an attribute. If you get an error that such-and-such attribute is missing, an obvious thing to do is to see if that attribute is present in the class or any super class.
btb4198 said:
So I modify the CustompSp class by adding a device attribute as follows:

btb4198 said:
Anyhow it seemed to work
What a surprise...
 

Related to FineTune the Pixel2Style2Pixel model with my custom data set

1. How do I fine-tune the Pixel2Style2Pixel model with my custom data set?

To fine-tune the Pixel2Style2Pixel model with your custom data set, you will need to first prepare your data by organizing it in the required format. Then, you can use the provided training script and specify your custom data set during training. Be sure to adjust any hyperparameters as needed for your specific data set.

2. What is the process for training the Pixel2Style2Pixel model with custom data?

The process for training the Pixel2Style2Pixel model with custom data involves loading your data set, initializing the model with pre-trained weights, and then running the training script with your custom data set. It is important to monitor the training process, adjust hyperparameters if necessary, and evaluate the model's performance on your custom data set.

3. Can I use transfer learning to fine-tune the Pixel2Style2Pixel model with my custom data set?

Yes, you can use transfer learning to fine-tune the Pixel2Style2Pixel model with your custom data set. By leveraging the pre-trained weights of the model on a large dataset, you can adapt the model to your specific data set with fewer training iterations. This can help improve the model's performance on your custom data set.

4. What are some best practices for fine-tuning the Pixel2Style2Pixel model with custom data?

Some best practices for fine-tuning the Pixel2Style2Pixel model with custom data include properly preparing and preprocessing your data, selecting an appropriate learning rate, monitoring the training process closely, and evaluating the model's performance on a validation set. It is also recommended to experiment with different hyperparameters and training strategies to optimize the model's performance.

5. How can I evaluate the performance of the fine-tuned Pixel2Style2Pixel model on my custom data set?

You can evaluate the performance of the fine-tuned Pixel2Style2Pixel model on your custom data set with image-reconstruction metrics rather than classification accuracy: for example, pixel-wise L2/MSE error, a perceptual similarity measure such as LPIPS, and, for faces, identity similarity between input and output. Additionally, you can visually inspect the model's outputs on sample images from your custom data set to assess its performance qualitatively. It is important to compare the model's performance on both training and validation data to ensure its generalization ability.

Similar threads

  • Programming and Computer Science
Replies
9
Views
2K
  • Programming and Computer Science
Replies
5
Views
3K
  • Programming and Computer Science
Replies
1
Views
1K
  • Programming and Computer Science
Replies
3
Views
2K
  • Programming and Computer Science
Replies
2
Views
897
  • Programming and Computer Science
Replies
1
Views
2K
  • Engineering and Comp Sci Homework Help
Replies
3
Views
1K
  • Programming and Computer Science
Replies
1
Views
3K
Replies
8
Views
3K
Back
Top