diff --git a/src/common/types/value/value.cpp b/src/common/types/value/value.cpp index 9ff18a2ad05..ada46d2b2a7 100644 --- a/src/common/types/value/value.cpp +++ b/src/common/types/value/value.cpp @@ -264,6 +264,26 @@ Value::Value(double val_) : isNull_{false}, childrenSize{0} { val.doubleVal = val_; } +Value::Value(decimal_t val_) : isNull_{false}, childrenSize{0} { + dataType = LogicalType::DECIMAL(val_.precision, val_.scale); + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT16: + val.int16Val = (int16_t)(val_.val); + break; + case PhysicalTypeID::INT32: + val.int32Val = (int32_t)(val_.val); + break; + case PhysicalTypeID::INT64: + val.int64Val = (int64_t)(val_.val); + break; + case PhysicalTypeID::INT128: + val.int128Val = val_.val; + break; + default: + KU_UNREACHABLE; + } +} + Value::Value(date_t val_) : isNull_{false}, childrenSize{0} { dataType = LogicalType::DATE(); val.int32Val = val_.days; diff --git a/src/include/common/types/decimal_t.h b/src/include/common/types/decimal_t.h new file mode 100644 index 00000000000..2f0aee2a83d --- /dev/null +++ b/src/include/common/types/decimal_t.h @@ -0,0 +1,20 @@ +#pragma once + +#include "int128_t.h" + +namespace kuzu { +namespace common { + +struct KUZU_API decimal_t { + + int128_t val = 0; + uint32_t precision = 18; + uint32_t scale = 3; + + decimal_t() {} + decimal_t(int128_t val, uint32_t prec, uint32_t scale) + : val(val), precision(prec), scale(scale) {} +}; + +} // namespace common +} // namespace kuzu diff --git a/src/include/common/types/value/value.h b/src/include/common/types/value/value.h index 46fa717e2bc..bdb9b05ed25 100644 --- a/src/include/common/types/value/value.h +++ b/src/include/common/types/value/value.h @@ -4,6 +4,7 @@ #include "common/api.h" #include "common/types/date_t.h" +#include "common/types/decimal_t.h" #include "common/types/int128_t.h" #include "common/types/internal_id_t.h" #include "common/types/interval_t.h" @@ -99,6 +100,10 @@ class Value { * @param val_ the float value to set. */ KUZU_API explicit Value(float val_); + /** + * @param val_ the decimal_t value to set + */ + KUZU_API explicit Value(decimal_t val_); /** * @param val_ the date value to set. */ diff --git a/src/parser/transform/transform_expression.cpp b/src/parser/transform/transform_expression.cpp index ad557ca3203..cfe21179a53 100644 --- a/src/parser/transform/transform_expression.cpp +++ b/src/parser/transform/transform_expression.cpp @@ -1,6 +1,7 @@ #include "function/aggregate/count_star.h" #include "function/arithmetic/vector_arithmetic_functions.h" #include "function/cast/functions/cast_from_string_functions.h" +#include "function/cast/functions/cast_string_non_nested_functions.h" #include "function/list/vector_list_functions.h" #include "function/string/vector_string_functions.h" #include "function/struct/vector_struct_functions.h" @@ -641,10 +642,27 @@ std::unique_ptr Transformer::transformIntegerLiteral( std::unique_ptr Transformer::transformDoubleLiteral( CypherParser::OC_DoubleLiteralContext& ctx) { auto text = ctx.RegularDecimalReal()->getText(); - ku_string_t literal{text.c_str(), text.length()}; - double result; - function::CastString::operation(literal, result); - return std::make_unique(Value(result), ctx.getText()); + if (text[0] == '-') { + text.erase(text.begin()); + } + auto type = LogicalType::DOUBLE(); + if (text.size() - 1 <= DECIMAL_PRECISION_LIMIT) { + auto decimalPoint = text.find('.'); + KU_ASSERT(decimalPoint != std::string::npos); + type = LogicalType::DECIMAL(text.size() - 1, text.size() - decimalPoint - 1); + } + text = ctx.RegularDecimalReal()->getText(); // undo changes + if (type.getLogicalTypeID() == LogicalTypeID::DECIMAL) { + int128_t val; + decimalCast(text.c_str(), text.length(), val, type); + decimal_t result(val, DecimalType::getPrecision(type), DecimalType::getScale(type)); + return std::make_unique(Value(result), ctx.getText()); + } else { + ku_string_t literal{text.c_str(), text.length()}; + double result; + function::CastString::operation(literal, result); + return std::make_unique(Value(result), ctx.getText()); + } } } // namespace parser