• 카테고리

    질문 & 답변
  • 세부 분야

    딥러닝 · 머신러닝

  • 해결 여부

    미해결

BaseEstimator 질문드립니다 !

21.07.19 20:48 작성 조회수 548

1

선생님 MyDummyClassifier 에서 BaseEstimator의 역할이 궁금합니다.. 아무리 검색해도 잘모르겠네요 ㅠ

그리고 fit 메서드는 아무것도 학습을 하지 않는데 굳이 학습/검증 세트로 나눈다음 fit(X_train, y_train) 을 수행한 이유도 잘모르겠습니다 ㅠㅠ

답변 1

답변을 작성해보세요.

5

안녕하십니까, 

MyDummyClassifier는 사이킷런의 Classifier 구현을 흉내낸것 입니다. 

사이킷런 프레임웍은 분류를 위한 수행 객체로 Classifier를 가집니다. 가령 DecisionTreeClassifier, RandomForestClassifier등 다양한 분류 알고리즘을 구현한 Classifier를 제공합니다. 보통은 이들 Classifier객체들은 Regressor 객체와 함께 Estimator라고 불립니다. 

이들 Estimator는 사이킷런 프레임웍에서 GridSearchCV, cross_val_score() 등 다양한 Utility class들과 함께 자연스럽게 결합될 수 있는데, 이걸 적용하려면 모든 Estimator들은 BaseEstimator라는 것을 상속 받아야 합니다.  그래서 MyDummyClassifier에서 BaseEstimator를 상속 받았습니다. 

물론 이렇게 BaseEstimator를 상속받지 않고, fit(), predict()를 구현할 수도 있습니다만, 사이킷런 프레임웍의 다른 Estimator 동작방식과 유사한 설명을 드리기 위해서 그렇게 구현한 것입니다. 

그리고 fit() 메서드는 적어주신대로 아무것도 학습하지 않습니다. 그런데 정확도(Accuracy)의 경우는 아무것도 학습하지 않은 상태임에도 좋은 예측 수치를 보여 줄수가 있습니다. 즉 그냥 찍어도 꽤 놓은 수치가 나올수 있는 경우가 바로 정확도 이기 때문에 정확도 수치에는 맹점이 있을 수 있습니다. 

이를 설명 드리기 위해,  Machine Learning의 일반적인 학습 프로세스, 가령 예를 들어 학습과 검증 세트로 나눈 다음에 fit(X_train, y_train)으로 아무것도 아닌 학습을 하고,  predict() 예측을 했을 때, 그렇게 나오면 안됨에도 불구하고 정확도 수치가 높게 나올 수 있는 경우를 보여 드리기 위해서 만든 것입니다.