multiple inputs
320
작성한 질문수 20
예시에서는 독립변수가 하나밖에 존재하지 않은데 2개 이상일때는 call 메서드에 어떻게 적용시키나요? 참고로 from_tensor_slices로 dataset을 만들었을 때 궁금합니다.
아래의 코드는 제가 만든 모델 코드의 일부분 입니다.
다음과 같이 from_tensor_slices로 총 7개의 변수를 가진 tf.dataset을 생성했습니다.
이후 main model의 call 메서드에 인자를 다음과 같이 6개를 지정해서 model을 학습시켰습니다. 이렇게 하는게 맞는지 궁금합니다.
train_ds = tf.data.Dataset.from_tensor_slices((train_user, train_item, train_user_bert, train_item_bert, train_user_roberta, train_item_roberta, train_rating))
def call(self, user, item, user_text_bert,item_text_bert, user_text_roberta, item_text_roberta):
user_vec = self.user_emb(user)
item_vec = self.item_emb(item)
for epoch in range(Epochs):
for a, b, c, d, e, f, g in train_ds:
with tf.GradientTape() as tape:
predictions = model(a, b, c, d, e, f)
loss = loss_object(g, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
답변 1
0
안녕하세요, 인프런 AI 인턴이에요.
여러 개의 독립 변수를 가진 경우, 각 변수마다 모델에 필요한 인자를 전달해야 합니다. 예를 들어, 7개의 변수를 가진 경우, call 메서드에서 각 변수에 대한 인자를 따로 전달해야 합니다.
아래는 예시 코드입니다.
def call(self, user, item, user_text_bert, item_text_bert, user_text_roberta, item_text_roberta):
user_vec = self.user_emb(user)
item_vec = self.item_emb(item)
# 나머지 변수에 대해서도 필요한 전처리 및 연산 수행
# 예시로 다음과 같이 출력해 볼 수 있습니다.
print(user_vec, item_vec, user_text_bert, item_text_bert, user_text_roberta, item_text_roberta)
# 모델의 나머지 부분 구현 코드
위의 코드에서는 user, item 변수는 따로 처리되고, 나머지 변수들도 필요한 작업을 수행한 후에 모델에 전달됩니다.
그리고 학습하는 코드 부분에서도 인자를 잘 전달하셨습니다.
for a, b, c, d, e, f, g in train_ds:
with tf.GradientTape() as tape:
predictions = model(a, b, c, d, e, f)
loss = loss_object(g, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
맞게 구현하셨습니다.
각 변수의 전처리 및 모델에 필요한 연산을 잘 적용하시고 사용하시면 됩니다. 추가적인 질문이 있으시면 언제든지 물어보세요. 즐거운 학습 되세요!
[업로드오류] 강의 내용과 제목 불일치
0
24
2
섹션22 퀴즈 질문
0
30
1
import torch가 안되는 경우는 어떻게 하나요?
0
31
1
소리가 겹쳐서 들려요
0
33
2
20강에서 파인튜닝 때 사용한 데이터가 없어졌습니다. LoRA Trainer 매개변수도 라이브러리 업그레이드로 수정되었습니다.
0
24
1
[개정판] 딥러닝 컴퓨터 비전 완벽 가이드 먼저? 구현하며 배우는 Transformer 먼저?
0
34
1
수업자료
0
29
2
전 강의와 전혀 이어지지가 않음
0
34
1
pytorch local 설치 옵션에 conda 가 없습니다.
0
40
3
pc에서는 괜찮은데 탭으로 들으니 화면확대시 화면이 까맙니다
0
24
1
강의 환경설정 질문
0
43
2
모든 자료 다운로드 누를때마다 똑같은 excel파일이 다운로드 받아짐. 노션 주소 공유되나요?
0
35
2
오토인코더+ Knn, SVC 로 해석하는경우
0
46
3
강의자료에 소스코드가 없는데요
0
51
3
강화학습저장 및 로드
0
61
1
Custom Dataset에서의 polygon 정보 관련
0
90
3
동영상 재생오류
0
55
1
epoch 1부터 loss가 너무 낮게 나와 학습이 안되네요
0
542
1
텍스트 데이터일 때의 dtype
0
267
1
12강 data split take와 skip
0
319
1
<tensorflow사용메뉴얼> 강의파일
0
291
0
SyntaxError: keyword can't be an expression
0
519
0
강의 감사합니다. 다섯번째 강의인 Model Implementation가 재생이 안됩니다.
0
173
0
unsupported operand type(s) for *: 'float' and 'NoneType'
0
1969
1





