- #1
btb4198
- 572
- 10
I want to fine-tune the Pixel2Style2Pixel model on my custom dataset, but I keep getting an error when I try to load the pre-trained weights. Here is my code:
I am getting this error: AttributeError: 'CustompSp' object has no attribute 'device'
Code to fine-tune the Pixel2Style2Pixel model:
# Load the pre-trained pSp model and build the options namespace it expects.
# NOTE(review): DEVICE (and os / SimpleNamespace) are assumed to come from an
# earlier notebook cell — they are not defined in this snippet.
os.chdir("/content/pixel2style2pixel")  # so `models.psp` resolves from the repo root
from models.psp import pSp

config = {
    "lr": 0.0001,
    "betas": (0.9, 0.999),
    "weight_decay": 1e-6,
    "stylegan_size": 1024,
    # Single source of truth for the checkpoint. Previously this entry was set
    # to MODEL_PATH and then unconditionally overwritten with this path.
    "checkpoint_path": "/content/pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt",
    "device": DEVICE,
    "input_nc": 3,  # assumes 3-channel (RGB) input images
    "output_size": 256,
    "encoder_type": 'GradualStyleEncoder',
}

# NOTE(review): n_styles, output_size and stylegan_size must agree with the
# architecture the checkpoint was trained with; pSp may also derive n_styles
# from output_size on its own — verify against models/psp.py.
updated_config = {**config, "n_styles": 16}
config_object = SimpleNamespace(**updated_config)
from pixel2style2pixel.models.psp import pSp
def get_keys(d, name):
if 'state_dict' in d:
d = d['state_dict']
d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k.startswith(name)}
return d_filt
class CustompSp(pSp):
def __init__(self, opts):
super().__init__(opts)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_weights(self):
ckpt = torch.load(self.opts.checkpoint_path, map_location=self.device)
filtered_state_dict = {k: v for k, v in ckpt['state_dict'].items() if 'encoder' in k}
self.encoder.load_state_dict(filtered_state_dict, strict=False)
decoder_state_dict = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items() if 'decoder' in k}
self.decoder.load_state_dict(decoder_state_dict, strict=False)
self.load_latent_avg(ckpt)
def load_latent_avg(self, ckpt):
if 'latent_avg' in ckpt:
self.latent_avg = ckpt['latent_avg'].to(self.device)
else:
self.latent_avg = None
# Build the model, switch it to training mode and move it to DEVICE.
# nn.Module.train() and nn.Module.to() both return the module itself, so the
# calls chain and `model` ends up bound to the same object as before.
model = CustompSp(config_object).train().to(DEVICE)
Full error message:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-71-36ef46aa58b7> in <cell line: 1>()
----> 1 model = CustompSp(config_object)
2 model.train()
3 model.to(DEVICE)
3 frames
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
1612 if name in modules:
1613 return modules[name]
-> 1614 raise AttributeError("'{}' object has no attribute '{}'".format(
1615 type(self).__name__, name))
1616
AttributeError: 'CustompSp' object has no attribute 'device'