package org.jpmml.sparkml;

import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.apache.spark.sql.catalyst.expressions.Abs;
import org.apache.spark.sql.catalyst.expressions.Acos;
import org.apache.spark.sql.catalyst.expressions.Add;
import org.apache.spark.sql.catalyst.expressions.Alias;
import org.apache.spark.sql.catalyst.expressions.And;
import org.apache.spark.sql.catalyst.expressions.Asin;
import org.apache.spark.sql.catalyst.expressions.Atan;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic;
import org.apache.spark.sql.catalyst.expressions.BinaryComparison;
import org.apache.spark.sql.catalyst.expressions.BinaryMathExpression;
import org.apache.spark.sql.catalyst.expressions.BinaryOperator;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Cast;
import org.apache.spark.sql.catalyst.expressions.Ceil;
import org.apache.spark.sql.catalyst.expressions.Concat;
import org.apache.spark.sql.catalyst.expressions.Cos;
import org.apache.spark.sql.catalyst.expressions.Cosh;
import org.apache.spark.sql.catalyst.expressions.Divide;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Exp;
import org.apache.spark.sql.catalyst.expressions.Expm1;
import org.apache.spark.sql.catalyst.expressions.Floor;
import org.apache.spark.sql.catalyst.expressions.GreaterThan;
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.Greatest;
import org.apache.spark.sql.catalyst.expressions.Hypot;
import org.apache.spark.sql.catalyst.expressions.If;
import org.apache.spark.sql.catalyst.expressions.In;
import org.apache.spark.sql.catalyst.expressions.IsNaN;
import org.apache.spark.sql.catalyst.expressions.IsNotNull;
import org.apache.spark.sql.catalyst.expressions.IsNull;
import org.apache.spark.sql.catalyst.expressions.Least;
import org.apache.spark.sql.catalyst.expressions.Length;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.expressions.Log;
import org.apache.spark.sql.catalyst.expressions.Log10;
import org.apache.spark.sql.catalyst.expressions.Log1p;
import org.apache.spark.sql.catalyst.expressions.Lower;
import org.apache.spark.sql.catalyst.expressions.Multiply;
import org.apache.spark.sql.catalyst.expressions.Not;
import org.apache.spark.sql.catalyst.expressions.Or;
import org.apache.spark.sql.catalyst.expressions.Pow;
import org.apache.spark.sql.catalyst.expressions.RLike;
import org.apache.spark.sql.catalyst.expressions.RegExpReplace;
import org.apache.spark.sql.catalyst.expressions.Rint;
import org.apache.spark.sql.catalyst.expressions.Sin;
import org.apache.spark.sql.catalyst.expressions.Sinh;
import org.apache.spark.sql.catalyst.expressions.Sqrt;
import org.apache.spark.sql.catalyst.expressions.StringReplace;
import org.apache.spark.sql.catalyst.expressions.StringTrim;
import org.apache.spark.sql.catalyst.expressions.Substring;
import org.apache.spark.sql.catalyst.expressions.Subtract;
import org.apache.spark.sql.catalyst.expressions.Tan;
import org.apache.spark.sql.catalyst.expressions.Tanh;
import org.apache.spark.sql.catalyst.expressions.UnaryExpression;
import org.apache.spark.sql.catalyst.expressions.UnaryMinus;
import org.apache.spark.sql.catalyst.expressions.UnaryPositive;
import org.apache.spark.sql.catalyst.expressions.Upper;
import org.apache.spark.sql.types.Decimal;
import org.dmg.pmml.Apply;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasDataType;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.IfElseBuilder;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.ExpressionCompactor;
import scala.Option;
import scala.Tuple2;
import scala.collection.JavaConversions;

/**
 * Translates Spark SQL Catalyst expressions into PMML expressions.
 *
 * <p>Each supported Catalyst node is mapped onto a PMML {@link FieldRef}, {@link Constant},
 * {@link Apply} (PMML built-in function call) or, for some casts, a reference to a new derived
 * field registered with the {@link SparkMLEncoder}. Unsupported nodes are rejected with an
 * {@link IllegalArgumentException}.
 */
public class ExpressionTranslator {

    /**
     * The {@code java.lang} package. Literal values whose class lives in this package (boxed
     * primitives, {@code String}) are passed through as-is; anything else is stringified.
     */
    private static final Package javaLangPackage = Object.class.getPackage();

    /**
     * Clamp for the length argument of the PMML {@code substring} function. Presumably needed
     * because Spark's two-argument {@code substring(str, pos)} uses {@code Integer.MAX_VALUE}
     * as its length - TODO confirm against the targeted Spark version.
     */
    private static final int MAX_STRING_LENGTH = 65536;

    private final SparkMLEncoder encoder;

    private ExpressionTranslator(SparkMLEncoder sparkMLEncoder) {
        this.encoder = Objects.requireNonNull(sparkMLEncoder);
    }

    /**
     * @return the encoder that receives any derived fields created during translation
     */
    public SparkMLEncoder getEncoder() {
        return this.encoder;
    }

    /**
     * Translates a Catalyst expression and compacts the result.
     *
     * @param sparkMLEncoder encoder used to register derived fields; must not be {@code null}
     * @param expression the Catalyst expression to translate
     * @return the equivalent PMML expression
     * @throws IllegalArgumentException if the expression (or a sub-expression) is not supported
     */
    public static Expression translate(SparkMLEncoder sparkMLEncoder, org.apache.spark.sql.catalyst.expressions.Expression expression) {
        return translate(sparkMLEncoder, expression, true);
    }

    /**
     * Translates a Catalyst expression, optionally compacting the result.
     *
     * @param sparkMLEncoder encoder used to register derived fields; must not be {@code null}
     * @param expression the Catalyst expression to translate
     * @param compact whether to run {@link ExpressionCompactor} over the translated expression
     * @return the equivalent PMML expression
     * @throws IllegalArgumentException if the expression (or a sub-expression) is not supported
     */
    public static Expression translate(SparkMLEncoder sparkMLEncoder, org.apache.spark.sql.catalyst.expressions.Expression expression, boolean compact) {
        ExpressionTranslator translator = new ExpressionTranslator(sparkMLEncoder);

        Expression pmmlExpression = translator.translateInternal(expression);
        if (compact) {
            new ExpressionCompactor().applyTo(pmmlExpression);
        }

        return pmmlExpression;
    }

    /**
     * Recursively translates one Catalyst node.
     *
     * <p>The order of the {@code instanceof} checks matters: several of the specific node types
     * handled first (e.g. {@link Alias}, {@link Cast}, {@link Length}) are themselves
     * {@link UnaryExpression}s and must not fall through to the generic unary handling.
     */
    private Expression translateInternal(org.apache.spark.sql.catalyst.expressions.Expression expression) {

        if (expression instanceof Alias) {
            Alias alias = (Alias) expression;

            return new AliasExpression(alias.name(), translateInternal(alias.child()));
        }

        if (expression instanceof AttributeReference) {
            AttributeReference attributeReference = (AttributeReference) expression;

            return new FieldRef(attributeReference.name());
        }

        if (expression instanceof BinaryMathExpression) {
            return translateBinaryMathExpression((BinaryMathExpression) expression);
        }

        if (expression instanceof BinaryOperator) {
            return translateBinaryOperator((BinaryOperator) expression);
        }

        if (expression instanceof CaseWhen) {
            return translateCaseWhen((CaseWhen) expression);
        }

        if (expression instanceof Cast) {
            return translateCast((Cast) expression);
        }

        if (expression instanceof Concat) {
            Concat concat = (Concat) expression;

            return appendArguments(ExpressionUtil.createApply("concat"), JavaConversions.seqAsJavaList(concat.children()));
        }

        if (expression instanceof Greatest) {
            Greatest greatest = (Greatest) expression;

            return appendArguments(ExpressionUtil.createApply("max"), JavaConversions.seqAsJavaList(greatest.children()));
        }

        if (expression instanceof If) {
            If ifExpression = (If) expression;

            return ExpressionUtil.createApply("if",
                translateInternal(ifExpression.predicate()),
                translateInternal(ifExpression.trueValue()),
                translateInternal(ifExpression.falseValue())
            );
        }

        if (expression instanceof In) {
            In in = (In) expression;

            // The tested value comes first, followed by the candidate values
            return appendArguments(ExpressionUtil.createApply("isIn", translateInternal(in.value())), JavaConversions.seqAsJavaList(in.list()));
        }

        if (expression instanceof Least) {
            Least least = (Least) expression;

            return appendArguments(ExpressionUtil.createApply("min"), JavaConversions.seqAsJavaList(least.children()));
        }

        if (expression instanceof Length) {
            Length length = (Length) expression;

            return ExpressionUtil.createApply("stringLength", translateInternal(length.child()));
        }

        if (expression instanceof Literal) {
            return translateLiteral((Literal) expression);
        }

        if (expression instanceof RegExpReplace) {
            RegExpReplace regExpReplace = (RegExpReplace) expression;

            return ExpressionUtil.createApply("replace",
                translateInternal(regExpReplace.subject()),
                translateInternal(regExpReplace.regexp()),
                translateInternal(regExpReplace.rep())
            );
        }

        if (expression instanceof RLike) {
            RLike rLike = (RLike) expression;

            return ExpressionUtil.createApply("matches", translateInternal(rLike.left()), translateInternal(rLike.right()));
        }

        if (expression instanceof StringReplace) {
            StringReplace stringReplace = (StringReplace) expression;

            // PMML "replace" is regex-based, whereas Spark's replace() is literal, so the search and
            // replacement strings have their special characters escaped
            return ExpressionUtil.createApply("replace",
                translateInternal(stringReplace.srcExpr()),
                transformString(translateInternal(stringReplace.searchExpr()), ExpressionTranslator::escapeSearchString),
                transformString(translateInternal(stringReplace.replaceExpr()), ExpressionTranslator::escapeReplacementString)
            );
        }

        if (expression instanceof StringTrim) {
            StringTrim stringTrim = (StringTrim) expression;

            if (stringTrim.trimStr().isDefined()) {
                throw new IllegalArgumentException("Spark SQL function '" + String.valueOf(stringTrim) + "' with a custom trim string is not supported");
            }

            return ExpressionUtil.createApply("trimBlanks", translateInternal(stringTrim.srcStr()));
        }

        if (expression instanceof Substring) {
            return translateSubstring((Substring) expression);
        }

        if (expression instanceof UnaryExpression) {
            return translateUnaryExpression((UnaryExpression) expression);
        }

        throw new IllegalArgumentException(formatMessage(expression));
    }

    /** Translates {@link Hypot} and {@link Pow} into the PMML functions of the same name. */
    private Expression translateBinaryMathExpression(BinaryMathExpression binaryMathExpression) {
        String function;

        if (binaryMathExpression instanceof Hypot) {
            function = "hypot";
        } else if (binaryMathExpression instanceof Pow) {
            function = "pow";
        } else {
            throw new IllegalArgumentException(formatMessage(binaryMathExpression));
        }

        org.apache.spark.sql.catalyst.expressions.Expression left = (org.apache.spark.sql.catalyst.expressions.Expression) binaryMathExpression.left();
        org.apache.spark.sql.catalyst.expressions.Expression right = (org.apache.spark.sql.catalyst.expressions.Expression) binaryMathExpression.right();

        return ExpressionUtil.createApply(function, translateInternal(left), translateInternal(right));
    }

    /** Translates logical, arithmetic and comparison operators into two-argument PMML function calls. */
    private Expression translateBinaryOperator(BinaryOperator binaryOperator) {
        String function = toBinaryFunction(binaryOperator);
        if (function == null) {
            throw new IllegalArgumentException(formatMessage(binaryOperator));
        }

        org.apache.spark.sql.catalyst.expressions.Expression left = (org.apache.spark.sql.catalyst.expressions.Expression) binaryOperator.left();
        org.apache.spark.sql.catalyst.expressions.Expression right = (org.apache.spark.sql.catalyst.expressions.Expression) binaryOperator.right();

        return ExpressionUtil.createApply(function, translateInternal(left), translateInternal(right));
    }

    /**
     * Maps a supported binary operator onto its PMML built-in function name.
     *
     * @return the PMML function name, or {@code null} if the operator is not supported
     */
    private static String toBinaryFunction(BinaryOperator binaryOperator) {
        // Logical operators
        if (binaryOperator instanceof And) {
            return "and";
        }
        if (binaryOperator instanceof Or) {
            return "or";
        }

        // Arithmetic operators
        if (binaryOperator instanceof Add) {
            return "+";
        }
        if (binaryOperator instanceof Subtract) {
            return "-";
        }
        if (binaryOperator instanceof Multiply) {
            return "*";
        }
        if (binaryOperator instanceof Divide) {
            return "/";
        }

        // Comparison operators
        if (binaryOperator instanceof EqualTo) {
            return "equal";
        }
        if (binaryOperator instanceof GreaterThan) {
            return "greaterThan";
        }
        if (binaryOperator instanceof GreaterThanOrEqual) {
            return "greaterOrEqual";
        }
        if (binaryOperator instanceof LessThan) {
            return "lessThan";
        }
        if (binaryOperator instanceof LessThanOrEqual) {
            return "lessOrEqual";
        }

        return null;
    }

    /** Translates {@code CASE WHEN ... THEN ... [ELSE ...] END} into a chain of PMML {@code if} calls. */
    private Expression translateCaseWhen(CaseWhen caseWhen) {
        List<?> branches = JavaConversions.seqAsJavaList(caseWhen.branches());
        if (branches.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one WHEN branch in Spark SQL expression '" + String.valueOf(caseWhen) + "'");
        }

        IfElseBuilder ifElseBuilder = new IfElseBuilder();

        for (Object branch : branches) {
            // Each branch is a (condition, value) pair
            Tuple2<?, ?> conditionAndValue = (Tuple2<?, ?>) branch;

            ifElseBuilder.add(
                translateInternal((org.apache.spark.sql.catalyst.expressions.Expression) conditionAndValue._1()),
                translateInternal((org.apache.spark.sql.catalyst.expressions.Expression) conditionAndValue._2())
            );
        }

        Option<?> elseValue = caseWhen.elseValue();
        if (elseValue.isDefined()) {
            ifElseBuilder.terminate(translateInternal((org.apache.spark.sql.catalyst.expressions.Expression) elseValue.get()));
        }

        return ifElseBuilder.build();
    }

    /**
     * Translates a cast.
     *
     * <p>If the translated child can carry a data type itself, the target type is set on it in
     * place. Otherwise a derived field of the target type is registered with the encoder, and a
     * reference to that field is returned.
     */
    private Expression translateCast(Cast cast) {
        org.apache.spark.sql.catalyst.expressions.Expression child = cast.child();

        Expression pmmlExpression = translateInternal(child);
        DataType dataType = DatasetUtil.translateDataType(cast.dataType());

        if (pmmlExpression instanceof HasDataType) {
            HasDataType<?> hasDataType = (HasDataType<?>) pmmlExpression;
            hasDataType.setDataType(dataType);

            return pmmlExpression;
        }

        // Reuse the user-supplied alias as the derived field name, if there is one
        String name = (pmmlExpression instanceof AliasExpression)
            ? ((AliasExpression) pmmlExpression).getName()
            : FieldNameUtil.create(dataType, String.valueOf(child));

        return new FieldRef(getEncoder().createDerivedField(name, TypeUtil.getOpType(dataType), dataType, AliasExpression.unwrap(pmmlExpression)));
    }

    /**
     * Translates a literal into a PMML constant.
     *
     * <p>{@code null} becomes a missing-value constant. {@link Decimal} values become string
     * constants holding their {@code toString()} form.
     */
    private static Expression translateLiteral(Literal literal) {
        Object value = literal.value();

        if (value == null) {
            return ExpressionUtil.createMissingConstant();
        }

        if (value instanceof Decimal) {
            return ExpressionUtil.createConstant(DataType.STRING, ((Decimal) value).toString());
        }

        return ExpressionUtil.createConstant(DatasetUtil.translateDataType(literal.dataType()), toSimpleObject(value));
    }

    /**
     * Translates {@code substring(str, pos, len)} into the PMML {@code substring} function.
     *
     * <p>Both {@code pos} and {@code len} must be integer literals. Spark positions are 1-based,
     * and position {@code 0} denotes the first character, the same as {@code 1}. Negative
     * (end-relative) positions have no PMML counterpart and are rejected.
     */
    private Expression translateSubstring(Substring substring) {
        org.apache.spark.sql.catalyst.expressions.Expression pos = substring.pos();

        int startPos = toIntLiteral(pos, "start position");
        if (startPos < 0) {
            throw new IllegalArgumentException("Expected absolute start position, got relative start position " + pos);
        } else if (startPos == 0) {
            // Spark treats position 0 as referring to the first character
            startPos = 1;
        }

        int length = toIntLiteral(substring.len(), "length");

        return ExpressionUtil.createApply("substring",
            translateInternal(substring.str()),
            ExpressionUtil.createConstant(Integer.valueOf(startPos)),
            ExpressionUtil.createConstant(Integer.valueOf(Math.min(length, MAX_STRING_LENGTH)))
        );
    }

    /**
     * Extracts the integer value of a literal argument.
     *
     * @throws IllegalArgumentException if the argument is not a {@link Literal}
     */
    private static int toIntLiteral(org.apache.spark.sql.catalyst.expressions.Expression expression, String description) {
        if (!(expression instanceof Literal)) {
            throw new IllegalArgumentException("Expected literal " + description + ", got " + expression);
        }

        return ValueUtil.asInt((Number) ((Literal) expression).value());
    }

    /** Translates the remaining supported unary expressions (math, string and null/NaN tests). */
    private Expression translateUnaryExpression(UnaryExpression unaryExpression) {
        org.apache.spark.sql.catalyst.expressions.Expression child = (org.apache.spark.sql.catalyst.expressions.Expression) unaryExpression.child();

        if (unaryExpression instanceof UnaryMinus) {
            return ExpressionUtil.toNegative(translateInternal(child));
        }

        if (unaryExpression instanceof UnaryPositive) {
            // Unary plus is the identity
            return translateInternal(child);
        }

        String function = toUnaryFunction(unaryExpression);
        if (function == null) {
            throw new IllegalArgumentException(formatMessage(unaryExpression));
        }

        return ExpressionUtil.createApply(function, translateInternal(child));
    }

    /**
     * Maps a supported unary expression onto its one-argument PMML function name.
     *
     * @return the PMML function name, or {@code null} if the expression is not supported
     */
    private static String toUnaryFunction(UnaryExpression unaryExpression) {
        if (unaryExpression instanceof Abs) {
            return "abs";
        }
        if (unaryExpression instanceof Acos) {
            return "acos";
        }
        if (unaryExpression instanceof Asin) {
            return "asin";
        }
        if (unaryExpression instanceof Atan) {
            return "atan";
        }
        if (unaryExpression instanceof Ceil) {
            return "ceil";
        }
        if (unaryExpression instanceof Cos) {
            return "cos";
        }
        if (unaryExpression instanceof Cosh) {
            return "cosh";
        }
        if (unaryExpression instanceof Exp) {
            return "exp";
        }
        if (unaryExpression instanceof Expm1) {
            return "expm1";
        }
        if (unaryExpression instanceof Floor) {
            return "floor";
        }
        if (unaryExpression instanceof Log) {
            return "ln";
        }
        if (unaryExpression instanceof Log10) {
            return "log10";
        }
        if (unaryExpression instanceof Log1p) {
            return "ln1p";
        }
        if (unaryExpression instanceof Lower) {
            return "lowercase";
        }
        if (unaryExpression instanceof IsNaN) {
            return "isNotValid";
        }
        if (unaryExpression instanceof IsNotNull) {
            return "isNotMissing";
        }
        if (unaryExpression instanceof IsNull) {
            return "isMissing";
        }
        if (unaryExpression instanceof Not) {
            return "not";
        }
        if (unaryExpression instanceof Rint) {
            return "rint";
        }
        if (unaryExpression instanceof Sin) {
            return "sin";
        }
        if (unaryExpression instanceof Sinh) {
            return "sinh";
        }
        if (unaryExpression instanceof Sqrt) {
            return "sqrt";
        }
        if (unaryExpression instanceof Tan) {
            return "tan";
        }
        if (unaryExpression instanceof Tanh) {
            return "tanh";
        }
        if (unaryExpression instanceof Upper) {
            return "uppercase";
        }

        return null;
    }

    /**
     * Translates each Catalyst child in order and appends it as an argument of {@code apply}.
     *
     * @return {@code apply}, for call chaining
     */
    private Apply appendArguments(Apply apply, List<?> children) {
        for (Object child : children) {
            apply.addExpressions(translateInternal((org.apache.spark.sql.catalyst.expressions.Expression) child));
        }

        return apply;
    }

    /** Escapes regex metacharacters so that {@code str} is matched literally by a regex search. */
    private static String escapeSearchString(String str) {
        return escape(str, "<([{\\^-=$!|]})?*+.>");
    }

    /** Escapes {@code \} and {@code $} so that {@code str} is inserted literally as a regex replacement. */
    private static String escapeReplacementString(String str) {
        return escape(str, "\\$");
    }

    /** Prefixes every character of {@code str} that occurs in {@code specialChars} with a backslash. */
    private static String escape(String str, String specialChars) {
        StringBuilder sb = new StringBuilder(str.length());

        for (int i = 0; i < str.length(); i++) {
            char c = str.charAt(i);

            if (specialChars.indexOf(c) > -1) {
                sb.append('\\');
            }

            sb.append(c);
        }

        return sb.toString();
    }

    /**
     * Rewrites the value of a string constant in place.
     *
     * @throws IllegalArgumentException if {@code expression} is not a string {@link Constant}
     */
    private static Constant transformString(Expression expression, Function<String, String> function) {
        if (!(expression instanceof Constant)) {
            throw new IllegalArgumentException("Expected a string literal, got " + expression.getClass().getName());
        }

        Constant constant = (Constant) expression;
        if (constant.getDataType() != DataType.STRING) {
            throw new IllegalArgumentException("Expected a string literal, got a constant of type " + constant.getDataType());
        }

        constant.setValue(function.apply((String) constant.getValue()));

        return constant;
    }

    /**
     * Returns {@code obj} unchanged if its class belongs to {@code java.lang}. Otherwise returns its
     * {@code toString()} form (e.g. for Spark-internal value classes).
     */
    private static Object toSimpleObject(Object obj) {
        return javaLangPackage.equals(obj.getClass().getPackage()) ? obj : obj.toString();
    }

    /** Builds the standard "not supported" message for a Catalyst expression, or returns {@code null} for {@code null}. */
    private static String formatMessage(org.apache.spark.sql.catalyst.expressions.Expression expression) {
        if (expression == null) {
            return null;
        }

        return "Spark SQL function '" + String.valueOf(expression) + "' (class " + expression.getClass().getName() + ") is not supported";
    }
}
