/*
 * Licensed to the University Corporation for Advanced Internet Development,
 * Inc. (UCAID) under one or more contributor license agreements.  See the
 * NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The UCAID licenses this file to You under the Apache
 * License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.plugin.authn.totp.impl;

import java.util.Collection;
import java.util.Collections;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.opensaml.profile.context.ProfileRequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.shibboleth.idp.attribute.IdPAttribute;
import net.shibboleth.idp.attribute.StringAttributeValue;
import net.shibboleth.idp.attribute.resolver.AttributeResolver;
import net.shibboleth.idp.attribute.resolver.context.AttributeResolutionContext;
import net.shibboleth.idp.plugin.authn.totp.context.TOTPContext;
import net.shibboleth.utilities.java.support.annotation.constraint.NonnullAfterInit;
import net.shibboleth.utilities.java.support.annotation.constraint.NotEmpty;
import net.shibboleth.utilities.java.support.annotation.constraint.ThreadSafeAfterInit;
import net.shibboleth.utilities.java.support.codec.Base32Support;
import net.shibboleth.utilities.java.support.codec.Base64Support;
import net.shibboleth.utilities.java.support.codec.DecodingException;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.primitive.StringSupport;
import net.shibboleth.utilities.java.support.service.ReloadableService;

/**
 * Token seed source implementation that leverages the {@link AttributeResolver}.
 *
 * <p>For the username held in the {@link TOTPContext}, a single attribute (by default
 * {@link #DEFAULT_ATTRIBUTE_ID}) is resolved. Each of its string values is decoded with the
 * configured encoding and added to the context's token seed collection.</p>
 */
@ThreadSafeAfterInit
public class AttributeResolverSeedSource extends AbstractSeedSource {

    /** Default attribute ID source. */
    @Nonnull @NotEmpty public static final String DEFAULT_ATTRIBUTE_ID = "tokenSeeds";
        
    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(AttributeResolverSeedSource.class);
    
    /** Attribute resolver service. */
    @NonnullAfterInit private ReloadableService<AttributeResolver> attributeResolver;
    
    /** Attribute ID to resolve; starts out as {@link #DEFAULT_ATTRIBUTE_ID} unless overridden. */
    @Nonnull @NotEmpty private String attributeId = DEFAULT_ATTRIBUTE_ID;
    
    /**
     * Set the {@link AttributeResolver} to use.
     * 
     * @param service the resolver
     */
    public void setAttributeResolver(@Nonnull final ReloadableService<AttributeResolver> service) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        attributeResolver = Constraint.isNotNull(service, "AttributeResolver cannot be null");
    }
    
    /**
     * Set the source attribute ID to resolve.
     *
     * <p>Defaults to {@link #DEFAULT_ATTRIBUTE_ID}. The value is trimmed before use.</p>
     * 
     * @param id attribute ID
     */
    public void setSourceAttribute(@Nonnull @NotEmpty final String id) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        
        attributeId = Constraint.isNotNull(StringSupport.trimOrNull(id), "Source attribute ID cannot be null or empty");
    }
        
    /** {@inheritDoc} */
    @Override
    protected void doInitialize() throws ComponentInitializationException {
        super.doInitialize();
        
        if (attributeResolver == null) {
            throw new ComponentInitializationException("AttributeResolver cannot be null");
        } else if (attributeId == null) {
            // Defensive only: the field has a default and the setter rejects null/empty input.
            throw new ComponentInitializationException("Source attribute ID cannot be null");
        }
    }

    /** {@inheritDoc} */
    @Override
    public void accept(@Nullable final ProfileRequestContext input) {
        ComponentSupport.ifNotInitializedThrowUninitializedComponentException(this);
        
        final TOTPContext totp = getTOTPContextLookupStrategy().apply(input);
        if (totp == null || totp.getUsername() == null) {
            log.warn("Unable to locate TOTPContext with username set");
            return;
        }

        final AttributeResolutionContext resCtx = totp.getSubcontext(AttributeResolutionContext.class, true);

        resCtx.setResolutionLabel("TOTP");
        resCtx.setPrincipal(totp.getUsername());
        resCtx.setRequestedIdPAttributeNames(Collections.singletonList(attributeId));

        log.debug("Resolving attribute {} for '{}'", attributeId, totp.getUsername());

        // Decoded seeds are appended directly to the context's own collection.
        final Collection<byte[]> seeds = totp.getTokenSeeds();

        try {
            // Resolve the attributes.
            resCtx.resolveAttributes(attributeResolver);

            final IdPAttribute attribute = resCtx.getResolvedIdPAttributes().get(attributeId);
            if (attribute != null) {
                // Only string-valued entries are treated as encoded seeds; other value types are skipped.
                attribute.getValues()
                    .stream()
                    .filter(StringAttributeValue.class::isInstance)
                    .map(StringAttributeValue.class::cast)
                    .map(StringAttributeValue::getValue)
                    .forEachOrdered(v -> {
                        try {
                            seeds.add(decodeSeed(v));
                        } catch (final DecodingException e) {
                            // One bad value should not discard the remaining seeds.
                            log.error("Unable to decode seed", e);
                        }
                    });
            }
        } finally {
            // The resolution context is transient; it is always removed so it is not left on the TOTPContext.
            totp.removeSubcontext(resCtx);
        }

        log.debug("Resolved {} seed(s) for '{}'", seeds.size(), totp.getUsername());
    }

    /**
     * Decode a single encoded seed value using the configured encoding.
     *
     * @param value the encoded seed
     * @return the decoded seed bytes
     * @throws DecodingException if the value cannot be decoded or the configured encoding is not supported
     */
    private byte[] decodeSeed(@Nonnull final String value) throws DecodingException {
        switch (getEncoding()) {
            case BASE32:
                return Base32Support.decode(value);

            case BASE64:
                return Base64Support.decode(value);

            default:
                throw new DecodingException("Unknown encoding type");
        }
    }

}